GN⁺: AMD GPU로 Llama 405B 미세 조정
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (우리의 여정)
소개
- 오픈 소스 모델이 커지면서 대규모 AI 훈련을 처리할 강력한 인프라의 필요성이 커짐
- Felafax는 AMD GPU에서 LLaMA 3.1 405B 모델을 미세 조정하여 AMD 하드웨어의 효율성을 입증함
- 모든 작업을 GitHub에 오픈 소스로 공개함
- AMD MI300X GPU는 NVIDIA AI 하드웨어에 비해 높은 성능을 제공함
- TensorWave의 지원으로 프로젝트가 가능했음
JAX란 무엇이며 왜 선택했는가
- JAX는 NumPy와 유사한 API, 자동 미분, Google's XLA 컴파일러를 결합한 강력한 머신러닝 라이브러리임
- 모델 병렬 처리를 위한 우수한 API를 제공하여 대규모 모델 훈련에 이상적임
JAX의 장점
- 순수 함수: JAX는 순수 함수를 작성하도록 권장하여 코드의 구성, 디버깅, 읽기가 쉬워짐
- 고급 병렬 처리: JAX의 유연한 JIT API는 대규모 훈련에 필수적인 고급 데이터 및 모델 병렬 처리를 지원함
- 깨끗한 코드베이스: JAX의 설계 철학은 하드웨어 플랫폼 간에 이식 가능한 코드를 작성하도록 장려함
JAX가 비-NVIDIA 하드웨어에서 뛰어난 이유
- 하드웨어 독립적 접근: JAX는 XLA 컴파일러를 활용하여 하드웨어 독립적인 중간 표현으로 계산을 컴파일함
- 플랫폼 독립적 최적화: XLA 컴파일러는 하드웨어와 독립적으로 최적화를 수행함
- 간편한 이식성: JAX를 사용하면 NVIDIA에서 AMD로 전환할 때 코드 변경이 최소화됨
AMD GPU에서 JAX 설정
- Docker 이미지를 가져오고 컨테이너를 시작한 후 설치를 확인함
- AMD MI300x GPU 8개를 사용하여 LLaMA 405B 모델을 훈련함
LLaMA 405B 훈련: 성능 및 확장성
- JAX를 사용하여 AMD GPU에서 LLaMA 405B 모델을 훈련함
- LoRA 미세 조정을 통해 모델 가중치와 LoRA 매개변수를 bfloat16 정밀도로 조정함
- 모델 크기: 약 800GB의 VRAM을 차지함
- LoRA 가중치 및 옵티마이저 상태: 약 400GB의 VRAM을 차지함
- 총 VRAM 사용량: 약 1200GB
- 훈련 속도: 초당 약 35 토큰
- 메모리 효율성: 약 70% 유지
- 확장성: JAX를 사용하여 8개의 GPU에서 거의 선형적으로 확장됨
우리의 훈련 설정
- LLaMA 3.1을 PyTorch에서 JAX로 변환함
- 모델 로딩 및 매개변수 샤딩을 통해 효율적으로 분산함
JAX에서 매개변수 샤딩
- JAX의 디바이스 메쉬 기능을 사용하여 8개의 AMD GPU에 모델을 효율적으로 분산함
- 매개변수 샤딩 규칙을 정의하여 각 텐서의 차원을 메쉬 축에 따라 샤딩함
LoRA 훈련 구현
- LoRA는 가중치 업데이트를 저랭크 행렬로 분해하여 훈련 가능한 매개변수 수를 줄임
- LoRADense 레이어를 구현하여 LoRA 매개변수를 포함함
- LoRA 매개변수를 효율적으로 분산하여 메모리 사용량과 계산 효율성을 최적화함
결론
- AMD GPU와 JAX를 사용하여 LLaMA 3.1 405B 모델을 미세 조정하는 경험이 매우 긍정적이었음
- JAX의 강력한 병렬 처리 기능과 하드웨어 독립적 접근 방식을 활용하여 모델을 효율적으로 분산함
- AMD GPU가 대규모 AI 훈련을 위한 강력한 대안임을 입증함
- GitHub 저장소에서 전체 코드를 확인하고 직접 실행할 수 있음
GN⁺의 정리
- 이 기사는 AMD GPU와 JAX를 사용하여 대규모 AI 모델을 효율적으로 훈련하는 방법을 설명함
- AMD 하드웨어가 NVIDIA에 비해 비용 효율적인 대안임을 강조함
- JAX의 하드웨어 독립적 접근 방식이 코드 이식성을 높이고 유지보수를 용이하게 함
- 대규모 모델 훈련에 관심 있는 사람들에게 유용한 정보와 실습 코드를 제공함
- 유사한 기능을 가진 프로젝트로는 NVIDIA의 CUDA와 PyTorch가 있음
Hacker News 의견
-
JAX를 사용하여 Llama3.1 405B 모델을 8xAMD MI300x GPU에서 미세 조정한 성과 공유
- JAX의 고급 샤딩 API 덕분에 뛰어난 성능을 달성함
- 블로그 포스트와 오픈 소스 코드 링크 제공: GitHub 링크
- NVIDIA 하드웨어가 아닌 TPU, AMD, Trainium에서 LLM을 미세 조정하고 서비스하는 AI 인프라를 구축하는 스타트업임
- 많은 회사들이 AMD GPU에서 PyTorch를 작동시키려고 하지만, 이는 어려운 길이라고 판단함
- PyTorch는 NVIDIA 생태계와 깊이 연관되어 있어 비-NVIDIA 하드웨어에서 작동시키려면 많은 수정이 필요함
- JAX는 비-NVIDIA 하드웨어에 더 적합하다고 믿음
- JAX에서는 ML 모델 코드가 하드웨어 독립적인 HLO 그래프로 컴파일되고, XLA 컴파일러가 하드웨어 특정 최적화를 수행함
- 동일한 JAX 코드를 Google TPU와 AMD GPU에서 변경 없이 실행 가능함
- 회사 전략은 JAX로 모델을 포팅하고, XLA 커널을 활용해 비-NVIDIA 백엔드에서 최대 성능을 추출하는 것임
- Llama 3.1을 PyTorch에서 JAX로 처음 포팅했으며, 이제 동일한 JAX 모델이 TPU와 AMD GPU에서 잘 작동함
- 비전과 저장소에 대한 의견을 듣고 싶어함
-
메모리 제약을 극복하고 JIT 컴파일된 버전을 실행하는 방법 탐구 제안
- 추가적인 성능 향상을 가져올 수 있을 것임
-
AMD GPU와 ROCm 지원에 대한 경험 공유
- 1년 전 AMD GPU와 ROCm 지원을 시도했으나, AMD가 NVIDIA를 따라잡기에는 아직 멀었다고 느낌
- JAX를 선택한 것은 흥미로운 접근법이지만, PyTorch에서 벗어나는 데 어떤 어려움이 있었는지 궁금함
-
405B 모델의 추론 측면에서 실험한 경험 공유
- 'torch.cuda'가 그렇게 나쁘지 않다고 생각함
- AMD 버전의 PyTorch가 이를 번역해주기 때문에 이름 문제일 뿐이라고 판단함
- rocm:pytorch 컨테이너를 사용하는 것이 rocm:jax 컨테이너를 사용하는 것만큼 쉬움
- 성능 데이터가 많이 게시되지 않았음을 지적함
- MFU(모델 활용률) 수치를 궁금해함
-
성능 데이터의 부재에 대한 질문
- AMD GPU의 대량 주문으로 인해 가치를 추출할 가능성에 대한 의문 제기
- "아니오"라는 인상을 받음
-
Obsidian(노트 테이킹 앱)이 왜 이 일을 하는지에 대한 의문
- 처음에는 Obsidian의 게시물인 줄 알았음
- GitHub.com과 GitHub.io를 아직 구분하지 않은 이유에 대한 의문
-
@dang에게 URL에 사용자 이름 포함 요청
- 이 게시물은 Obsidian 자체가 아닌 사용자 생성 블로그에 관한 것임