2P by neo 19일전 | favorite | 댓글 1개

Felafax BlogTune Llama3 405B on AMD MI300x (우리의 여정)

소개

  • 오픈 소스 모델이 커지면서 대규모 AI 훈련을 처리할 강력한 인프라의 필요성이 커짐
  • Felafax는 AMD GPU에서 LLaMA 3.1 405B 모델을 미세 조정하여 AMD 하드웨어의 효율성을 입증함
  • 모든 작업을 GitHub에 오픈 소스로 공개함
  • AMD MI300X GPU는 NVIDIA AI 하드웨어에 비해 높은 성능을 제공함
  • TensorWave의 지원으로 프로젝트가 가능했음

JAX란 무엇이며 왜 선택했는가

  • JAX는 NumPy와 유사한 API, 자동 미분, Google's XLA 컴파일러를 결합한 강력한 머신러닝 라이브러리임
  • 모델 병렬 처리를 위한 우수한 API를 제공하여 대규모 모델 훈련에 이상적임

JAX의 장점

  • 순수 함수: JAX는 순수 함수를 작성하도록 권장하여 코드의 구성, 디버깅, 읽기가 쉬워짐
  • 고급 병렬 처리: JAX의 유연한 JIT API는 대규모 훈련에 필수적인 고급 데이터 및 모델 병렬 처리를 지원함
  • 깨끗한 코드베이스: JAX의 설계 철학은 하드웨어 플랫폼 간에 이식 가능한 코드를 작성하도록 장려함

JAX가 비-NVIDIA 하드웨어에서 뛰어난 이유

  • 하드웨어 독립적 접근: JAX는 XLA 컴파일러를 활용하여 하드웨어 독립적인 중간 표현으로 계산을 컴파일함
  • 플랫폼 독립적 최적화: XLA 컴파일러는 하드웨어와 독립적으로 최적화를 수행함
  • 간편한 이식성: JAX를 사용하면 NVIDIA에서 AMD로 전환할 때 코드 변경이 최소화됨

AMD GPU에서 JAX 설정

  • Docker 이미지를 가져오고 컨테이너를 시작한 후 설치를 확인함
  • AMD MI300x GPU 8개를 사용하여 LLaMA 405B 모델을 훈련함

LLaMA 405B 훈련: 성능 및 확장성

  • JAX를 사용하여 AMD GPU에서 LLaMA 405B 모델을 훈련함
  • LoRA 미세 조정을 통해 모델 가중치와 LoRA 매개변수를 bfloat16 정밀도로 조정함
  • 모델 크기: 약 800GB의 VRAM을 차지함
  • LoRA 가중치 및 옵티마이저 상태: 약 400GB의 VRAM을 차지함
  • 총 VRAM 사용량: 약 1200GB
  • 훈련 속도: 초당 약 35 토큰
  • 메모리 효율성: 약 70% 유지
  • 확장성: JAX를 사용하여 8개의 GPU에서 거의 선형적으로 확장됨

우리의 훈련 설정

  • LLaMA 3.1을 PyTorch에서 JAX로 변환함
  • 모델 로딩 및 매개변수 샤딩을 통해 효율적으로 분산함

JAX에서 매개변수 샤딩

  • JAX의 디바이스 메쉬 기능을 사용하여 8개의 AMD GPU에 모델을 효율적으로 분산함
  • 매개변수 샤딩 규칙을 정의하여 각 텐서의 차원을 메쉬 축에 따라 샤딩함

LoRA 훈련 구현

  • LoRA는 가중치 업데이트를 저랭크 행렬로 분해하여 훈련 가능한 매개변수 수를 줄임
  • LoRADense 레이어를 구현하여 LoRA 매개변수를 포함함
  • LoRA 매개변수를 효율적으로 분산하여 메모리 사용량과 계산 효율성을 최적화함

결론

  • AMD GPU와 JAX를 사용하여 LLaMA 3.1 405B 모델을 미세 조정하는 경험이 매우 긍정적이었음
  • JAX의 강력한 병렬 처리 기능과 하드웨어 독립적 접근 방식을 활용하여 모델을 효율적으로 분산함
  • AMD GPU가 대규모 AI 훈련을 위한 강력한 대안임을 입증함
  • GitHub 저장소에서 전체 코드를 확인하고 직접 실행할 수 있음

GN⁺의 정리

  • 이 기사는 AMD GPU와 JAX를 사용하여 대규모 AI 모델을 효율적으로 훈련하는 방법을 설명함
  • AMD 하드웨어가 NVIDIA에 비해 비용 효율적인 대안임을 강조함
  • JAX의 하드웨어 독립적 접근 방식이 코드 이식성을 높이고 유지보수를 용이하게 함
  • 대규모 모델 훈련에 관심 있는 사람들에게 유용한 정보와 실습 코드를 제공함
  • 유사한 기능을 가진 프로젝트로는 NVIDIA의 CUDA와 PyTorch가 있음
Hacker News 의견
  • JAX를 사용하여 Llama3.1 405B 모델을 8xAMD MI300x GPU에서 미세 조정한 성과 공유

    • JAX의 고급 샤딩 API 덕분에 뛰어난 성능을 달성함
    • 블로그 포스트와 오픈 소스 코드 링크 제공: GitHub 링크
    • NVIDIA 하드웨어가 아닌 TPU, AMD, Trainium에서 LLM을 미세 조정하고 서비스하는 AI 인프라를 구축하는 스타트업임
    • 많은 회사들이 AMD GPU에서 PyTorch를 작동시키려고 하지만, 이는 어려운 길이라고 판단함
    • PyTorch는 NVIDIA 생태계와 깊이 연관되어 있어 비-NVIDIA 하드웨어에서 작동시키려면 많은 수정이 필요함
    • JAX는 비-NVIDIA 하드웨어에 더 적합하다고 믿음
    • JAX에서는 ML 모델 코드가 하드웨어 독립적인 HLO 그래프로 컴파일되고, XLA 컴파일러가 하드웨어 특정 최적화를 수행함
    • 동일한 JAX 코드를 Google TPU와 AMD GPU에서 변경 없이 실행 가능함
    • 회사 전략은 JAX로 모델을 포팅하고, XLA 커널을 활용해 비-NVIDIA 백엔드에서 최대 성능을 추출하는 것임
    • Llama 3.1을 PyTorch에서 JAX로 처음 포팅했으며, 이제 동일한 JAX 모델이 TPU와 AMD GPU에서 잘 작동함
    • 비전과 저장소에 대한 의견을 듣고 싶어함
  • 메모리 제약을 극복하고 JIT 컴파일된 버전을 실행하는 방법 탐구 제안

    • 추가적인 성능 향상을 가져올 수 있을 것임
  • AMD GPU와 ROCm 지원에 대한 경험 공유

    • 1년 전 AMD GPU와 ROCm 지원을 시도했으나, AMD가 NVIDIA를 따라잡기에는 아직 멀었다고 느낌
    • JAX를 선택한 것은 흥미로운 접근법이지만, PyTorch에서 벗어나는 데 어떤 어려움이 있었는지 궁금함
  • 405B 모델의 추론 측면에서 실험한 경험 공유

    • 'torch.cuda'가 그렇게 나쁘지 않다고 생각함
    • AMD 버전의 PyTorch가 이를 번역해주기 때문에 이름 문제일 뿐이라고 판단함
    • rocm:pytorch 컨테이너를 사용하는 것이 rocm:jax 컨테이너를 사용하는 것만큼 쉬움
    • 성능 데이터가 많이 게시되지 않았음을 지적함
    • MFU(모델 활용률) 수치를 궁금해함
  • 성능 데이터의 부재에 대한 질문

    • AMD GPU의 대량 주문으로 인해 가치를 추출할 가능성에 대한 의문 제기
    • "아니오"라는 인상을 받음
  • Obsidian(노트 테이킹 앱)이 왜 이 일을 하는지에 대한 의문

    • 처음에는 Obsidian의 게시물인 줄 알았음
    • GitHub.com과 GitHub.io를 아직 구분하지 않은 이유에 대한 의문
  • @dang에게 URL에 사용자 이름 포함 요청

    • 이 게시물은 Obsidian 자체가 아닌 사용자 생성 블로그에 관한 것임