2P by GN⁺ | ★ favorite | 댓글 1개

Felafax BlogTune Llama3 405B on AMD MI300x (우리의 여정)

소개

  • 오픈 소스 모델이 커지면서 대규모 AI 훈련을 처리할 강력한 인프라의 필요성이 커짐
  • Felafax는 AMD GPU에서 LLaMA 3.1 405B 모델을 미세 조정하여 AMD 하드웨어의 효율성을 입증함
  • 모든 작업을 GitHub에 오픈 소스로 공개함
  • AMD MI300X GPU는 NVIDIA AI 하드웨어에 비해 높은 성능을 제공함
  • TensorWave의 지원으로 프로젝트가 가능했음

JAX란 무엇이며 왜 선택했는가

  • JAX는 NumPy와 유사한 API, 자동 미분, Google's XLA 컴파일러를 결합한 강력한 머신러닝 라이브러리임
  • 모델 병렬 처리를 위한 우수한 API를 제공하여 대규모 모델 훈련에 이상적임

JAX의 장점

  • 순수 함수: JAX는 순수 함수를 작성하도록 권장하여 코드의 구성, 디버깅, 읽기가 쉬워짐
  • 고급 병렬 처리: JAX의 유연한 JIT API는 대규모 훈련에 필수적인 고급 데이터 및 모델 병렬 처리를 지원함
  • 깨끗한 코드베이스: JAX의 설계 철학은 하드웨어 플랫폼 간에 이식 가능한 코드를 작성하도록 장려함

JAX가 비-NVIDIA 하드웨어에서 뛰어난 이유

  • 하드웨어 독립적 접근: JAX는 XLA 컴파일러를 활용하여 하드웨어 독립적인 중간 표현으로 계산을 컴파일함
  • 플랫폼 독립적 최적화: XLA 컴파일러는 하드웨어와 독립적으로 최적화를 수행함
  • 간편한 이식성: JAX를 사용하면 NVIDIA에서 AMD로 전환할 때 코드 변경이 최소화됨

AMD GPU에서 JAX 설정

  • Docker 이미지를 가져오고 컨테이너를 시작한 후 설치를 확인함
  • AMD MI300x GPU 8개를 사용하여 LLaMA 405B 모델을 훈련함

LLaMA 405B 훈련: 성능 및 확장성

  • JAX를 사용하여 AMD GPU에서 LLaMA 405B 모델을 훈련함
  • LoRA 미세 조정을 통해 모델 가중치와 LoRA 매개변수를 bfloat16 정밀도로 조정함
  • 모델 크기: 약 800GB의 VRAM을 차지함
  • LoRA 가중치 및 옵티마이저 상태: 약 400GB의 VRAM을 차지함
  • 총 VRAM 사용량: 약 1200GB
  • 훈련 속도: 초당 약 35 토큰
  • 메모리 효율성: 약 70% 유지
  • 확장성: JAX를 사용하여 8개의 GPU에서 거의 선형적으로 확장됨

우리의 훈련 설정

  • LLaMA 3.1을 PyTorch에서 JAX로 변환함
  • 모델 로딩 및 매개변수 샤딩을 통해 효율적으로 분산함

JAX에서 매개변수 샤딩

  • JAX의 디바이스 메쉬 기능을 사용하여 8개의 AMD GPU에 모델을 효율적으로 분산함
  • 매개변수 샤딩 규칙을 정의하여 각 텐서의 차원을 메쉬 축에 따라 샤딩함

LoRA 훈련 구현

  • LoRA는 가중치 업데이트를 저랭크 행렬로 분해하여 훈련 가능한 매개변수 수를 줄임
  • LoRADense 레이어를 구현하여 LoRA 매개변수를 포함함
  • LoRA 매개변수를 효율적으로 분산하여 메모리 사용량과 계산 효율성을 최적화함

결론

  • AMD GPU와 JAX를 사용하여 LLaMA 3.1 405B 모델을 미세 조정하는 경험이 매우 긍정적이었음
  • JAX의 강력한 병렬 처리 기능과 하드웨어 독립적 접근 방식을 활용하여 모델을 효율적으로 분산함
  • AMD GPU가 대규모 AI 훈련을 위한 강력한 대안임을 입증함
  • GitHub 저장소에서 전체 코드를 확인하고 직접 실행할 수 있음

GN⁺의 정리

  • 이 기사는 AMD GPU와 JAX를 사용하여 대규모 AI 모델을 효율적으로 훈련하는 방법을 설명함
  • AMD 하드웨어가 NVIDIA에 비해 비용 효율적인 대안임을 강조함
  • JAX의 하드웨어 독립적 접근 방식이 코드 이식성을 높이고 유지보수를 용이하게 함
  • 대규모 모델 훈련에 관심 있는 사람들에게 유용한 정보와 실습 코드를 제공함
  • 유사한 기능을 가진 프로젝트로는 NVIDIA의 CUDA와 PyTorch가 있음

댓글과 토론

Hacker News 의견들
  • 최근 8xAMD MI300x GPU에서 PyTorch 대신 JAX로 llama3.1 405B 모델을 미세조정했음
    JAX의 고급 샤딩 API 덕분에 좋은 성능을 냈고, 사용한 샤딩 기법은 블로그에 정리했음. 코드도 공개함: https://github.com/felafax/felafax
    우리는 NVIDIA가 아닌 하드웨어(TPU, AMD, Trainium)에서 LLM 미세조정과 서빙을 위한 AI 인프라를 만드는 작은 스타트업임
    많은 회사가 AMD GPU에서 PyTorch를 돌리려 하지만, PyTorch는 torch.cudascaled_dot_product_attention처럼 NVIDIA 생태계와 깊게 얽혀 있어 “탈-NVIDIA화”가 많이 필요하다고 봄
    JAX는 모델 코드가 하드웨어 독립적인 HLO 그래프로 컴파일되고, 이후 XLA 컴파일러가 최적화한 뒤 하드웨어별 최적화를 적용하므로 NVIDIA가 아닌 하드웨어에 더 잘 맞는다고 생각함. 같은 LLaMA3 JAX 코드가 Google TPU와 AMD GPU에서 수정 없이 동작했음
    회사 전략은 모델을 먼저 JAX로 포팅한 뒤, JAX 프레임워크와 XLA 커널을 활용해 NVIDIA가 아닌 백엔드에서 최대 성능을 끌어내는 것임. 그래서 Llama 3.1을 PyTorch에서 JAX로 먼저 옮겼고, 같은 JAX 모델이 TPU와 AMD GPU에서 잘 동작함

    • AMD GPU에서 PyTorch를 CUDA 코드 변경 없이 돌리는 데 별문제가 없었음. MosaicML 블로그도 참고할 만함: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • Llama 3.1의 JAX 포팅 정확도는 어떻게 검증하고 있는지 궁금함
      개인적으로 PyTorch를 쓰는 주된 이유는 원본 모델이 PyTorch로 만들어졌기 때문임. 서로 다른 모델 버전에서 로직이 같아 보여도, 엄청난 데이터 규모에서는 아주 작은 부동소수점 오차가 누적되어 모델 드리프트가 생길 수 있음
      큰 모델에서 이런 정확도 불일치를 디버깅하는 건 지옥의 10번째 원보다 더 괴로운 일에 가까웠음
    • JAX가 행렬 곱셈이나 FlashAttention을 자체 구현으로 갖고 있는지, 아니면 PyTorch처럼 ROCm 구현을 쓰는지 궁금함. 예를 들면 hipblaslt, Composable Kernel FA 같은 것들임
      JAX를 잘 알지는 못하지만, MI300x에서 PyTorch 학습 성능이 처참한 이유의 상당 부분은 내부에서 쓰는 ROCm 라이브러리 성능이 느리기 때문이라고 봄
    • 7900 XTX 같은 소비자용 카드에서도 동작하는지 궁금함
      여기서 동작한다는 건 드라이버 잡느라 2주를 보내고 나서 서버를 다시는 업데이트하지 못하는 상태를 말하는 게 아님
    • 마이그레이션이라면 같은 모델을 PyTorch 버전과 비교한 실제 수치가 있는지 궁금함. 글의 비교표는 기술적인 측면에 가까워 보임
      마주친 기술적 이슈도 궁금함
  • 분명히 말하면 이 성능은 꽤 나쁨. 아마 컴파일을 제대로 동작시키지 못한 탓으로 보임
    405B 모델에서 35토큰/초가 나오는데, 이는 약 85테라플롭스에 해당함. 8개의 MI300x GPU는 10.4페타플롭스 수준이므로 MFU가 약 0.8%임
    decent한 학습 성능인 30~40% MFU보다 40~50배 낮은 수치라서, AMD 입장에서는 소프트웨어 스택이 병목이기를 바랄 듯함

    • 나도 정확히 그걸 묻고 싶었음
      GitHub 페이지에서는 “Google Cloud TPU에서 LLaMa3.1을 30% 낮은 비용으로 튜닝할 수 있다”고 하지만, 성능은 언급하지 않음
  • 훌륭한 작업임. 1년 전쯤 AMD GPU와 ROCm 지원을 조금 만져봤는데, AMD가 Nvidia를 따라잡으려면 아직 갈 길이 멀다는 게 분명했음
    JAX를 선택한 접근은 흥미로운데, 기계학습 표준 라이브러리에 가까운 PyTorch에서 벗어나면서 어떤 어려움이 있었는지 궁금함

    • 몇 주 전에 우리의 여정을 설명한 Show HN을 올렸음: https://news.ycombinator.com/item?id=41512142
      처음에는 TPU에서 LLaMA 3를 미세조정하는 것이 목표였지만, PyTorch XLA가 투박해서 모델을 JAX로 다시 작성하기로 했음
      앞서 말했듯이 JAX가 NVIDIA가 아닌 GPU에 더 나은 플랫폼이라고 보고, JAX+openXLA 위에서 NVIDIA가 아닌 GPU용 인프라를 만들고 싶음
    • Debian 12 시스템에서 AMD ROCm을 동작시키지 못하고 있고, 그래서 Ollama가 GPU 대신 CPU를 쓰는 것 같음. 아직 갈 길이 멀어 보임
  • 좋은 작업임. 지난 주말에 나도 405B의 추론 쪽을 만져보고 있었음 [0]
    torch.cuda가 그렇게 나쁘다는 데는 확신이 안 섬. AMD용 PyTorch가 그걸 대신 변환해 주기 때문임. 본질적인 문제라기보다는 이름 문제에 가까움
    실제로 rocm:pytorch 컨테이너를 가져오는 것도 rocm:jax 컨테이너를 가져오는 것만큼 쉬움
    게시된 수치가 많지 않은데, MFU는 얼마가 나왔는지 궁금함
    [0] https://x.com/HotAisle/status/1837580046732874026

    • 좋음
      MFU는 계산해야 함. GPU와 VRAM 세부 정보는 저장소에서 볼 수 있음: https://dub.sh/amd-405b-res
      다음 주말에 학습 실행을 다시 시도하면서 전체 학습 단계를 JIT 컴파일하고, 그때 MFU를 계산할 계획임
  • 우리가 ZML에서 측정했을 때 MI300X는 H100보다 30% 빨랐음. 훌륭한 칩들임

  • 8xAMD MI300 호스트를 빌릴 수 있는 클라우드 제공자가 있는지 궁금함
    업무상 AWS를 많이 쓰는데 AMD GPU를 한번 써보고 싶었음

    • 참고로 우리 회사가 8xMI300x를 임대하고 있으니 연락해도 됨
    • Oracle은 제공함. 다른 곳들도 따라올 가능성이 크지만, 작은 업체들이 상대하기는 더 합리적일 거라고 봄
  • 성능 데이터는 어디에 있나?

    • GitHub 저장소에 GPU와 VRAM 사용률 데이터를 추가했음: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      코드와 VRAM 제약 때문에 405B 모델의 JIT 컴파일 버전은 실행하지 못했음. 이 부분은 더 조사해야 함
      전체 학습 실행은 JAX 즉시 실행 모드로 수행했기 때문에 성능 개선 여지가 큼
      즉시 실행 모드에서도 GPU 사용률이 전반적으로 약 30~40%였고, 꽤 괜찮은 편임. JIT를 쓰면 GPU 사용률이 50~60%까지 쉽게 올라갈 수 있다고 봄
  • 가능하다면 메모리 제약을 극복해서 JIT 컴파일 버전을 실행하는 방법을 탐색하면 흥미로울 듯함. 추가 성능 개선으로 이어질 수 있음

    • 동의함. 아직 끌어낼 성능이 많이 남아 있음
      JIT 컴파일된 학습 단계, 더 최적화된 데이터 로딩과 샤딩, 그래디언트 누적, 활성화 체크포인팅이 필요함
      계속 만들고 개선 사항을 모두 구현한 뒤 곧 다시 블로그를 올릴 예정임
  • AMD가 GPU 대량 주문과 공급 부족을 통해 여기서 가치를 뽑아내는 데 조금이라도 가까워졌는지 궁금함
    내 인상은 “아니다”에 가까움

    • 비꼬는 건 알겠음. 하지만 지금 시점에서 AI의 하드웨어와 소프트웨어를 단일 공급원에 전부 맡길 생각이 아니라면, 대안을 향해 움직이기 시작해야 함
      상대는 엄청난 선행 우위를 갖고 있고, 소프트웨어 쪽에서 할 일이 분명히 많음. 시간이 필요함
  • 왜 노트 앱인 Obsidian이 이걸 하고 있나?

    • 그런 게 아님. 이 회사가 문서 게시에 Obsidian Publish를 쓰고 있는 것임