▲GN⁺ 2024-09-24 | parent | ★ favorite | on: AMD GPU로 Llama 405B 미세 조정(publish.obsidian.md)Hacker News 의견 JAX를 사용하여 Llama3.1 405B 모델을 8xAMD MI300x GPU에서 미세 조정한 성과 공유 JAX의 고급 샤딩 API 덕분에 뛰어난 성능을 달성함 블로그 포스트와 오픈 소스 코드 링크 제공: GitHub 링크 NVIDIA 하드웨어가 아닌 TPU, AMD, Trainium에서 LLM을 미세 조정하고 서비스하는 AI 인프라를 구축하는 스타트업임 많은 회사들이 AMD GPU에서 PyTorch를 작동시키려고 하지만, 이는 어려운 길이라고 판단함 PyTorch는 NVIDIA 생태계와 깊이 연관되어 있어 비-NVIDIA 하드웨어에서 작동시키려면 많은 수정이 필요함 JAX는 비-NVIDIA 하드웨어에 더 적합하다고 믿음 JAX에서는 ML 모델 코드가 하드웨어 독립적인 HLO 그래프로 컴파일되고, XLA 컴파일러가 하드웨어 특정 최적화를 수행함 동일한 JAX 코드를 Google TPU와 AMD GPU에서 변경 없이 실행 가능함 회사 전략은 JAX로 모델을 포팅하고, XLA 커널을 활용해 비-NVIDIA 백엔드에서 최대 성능을 추출하는 것임 Llama 3.1을 PyTorch에서 JAX로 처음 포팅했으며, 이제 동일한 JAX 모델이 TPU와 AMD GPU에서 잘 작동함 비전과 저장소에 대한 의견을 듣고 싶어함 메모리 제약을 극복하고 JIT 컴파일된 버전을 실행하는 방법 탐구 제안 추가적인 성능 향상을 가져올 수 있을 것임 AMD GPU와 ROCm 지원에 대한 경험 공유 1년 전 AMD GPU와 ROCm 지원을 시도했으나, AMD가 NVIDIA를 따라잡기에는 아직 멀었다고 느낌 JAX를 선택한 것은 흥미로운 접근법이지만, PyTorch에서 벗어나는 데 어떤 어려움이 있었는지 궁금함 405B 모델의 추론 측면에서 실험한 경험 공유 'torch.cuda'가 그렇게 나쁘지 않다고 생각함 AMD 버전의 PyTorch가 이를 번역해주기 때문에 이름 문제일 뿐이라고 판단함 rocm:pytorch 컨테이너를 사용하는 것이 rocm:jax 컨테이너를 사용하는 것만큼 쉬움 성능 데이터가 많이 게시되지 않았음을 지적함 MFU(모델 활용률) 수치를 궁금해함 성능 데이터의 부재에 대한 질문 AMD GPU의 대량 주문으로 인해 가치를 추출할 가능성에 대한 의문 제기 "아니오"라는 인상을 받음 Obsidian(노트 테이킹 앱)이 왜 이 일을 하는지에 대한 의문 처음에는 Obsidian의 게시물인 줄 알았음 GitHub.com과 GitHub.io를 아직 구분하지 않은 이유에 대한 의문 @dang에게 URL에 사용자 이름 포함 요청 이 게시물은 Obsidian 자체가 아닌 사용자 생성 블로그에 관한 것임
Hacker News 의견
JAX를 사용하여 Llama3.1 405B 모델을 8xAMD MI300x GPU에서 미세 조정한 성과 공유
메모리 제약을 극복하고 JIT 컴파일된 버전을 실행하는 방법 탐구 제안
AMD GPU와 ROCm 지원에 대한 경험 공유
405B 모델의 추론 측면에서 실험한 경험 공유
성능 데이터의 부재에 대한 질문
Obsidian(노트 테이킹 앱)이 왜 이 일을 하는지에 대한 의문
@dang에게 URL에 사용자 이름 포함 요청