GN⁺ 2024-09-24 | parent | ★ favorite | on: AMD GPU로 Llama 405B 미세 조정(publish.obsidian.md)
Hacker News 의견
  • JAX를 사용하여 Llama3.1 405B 모델을 8xAMD MI300x GPU에서 미세 조정한 성과 공유

    • JAX의 고급 샤딩 API 덕분에 뛰어난 성능을 달성함
    • 블로그 포스트와 오픈 소스 코드 링크 제공: GitHub 링크
    • NVIDIA 하드웨어가 아닌 TPU, AMD, Trainium에서 LLM을 미세 조정하고 서비스하는 AI 인프라를 구축하는 스타트업임
    • 많은 회사들이 AMD GPU에서 PyTorch를 작동시키려고 하지만, 이는 어려운 길이라고 판단함
    • PyTorch는 NVIDIA 생태계와 깊이 연관되어 있어 비-NVIDIA 하드웨어에서 작동시키려면 많은 수정이 필요함
    • JAX는 비-NVIDIA 하드웨어에 더 적합하다고 믿음
    • JAX에서는 ML 모델 코드가 하드웨어 독립적인 HLO 그래프로 컴파일되고, XLA 컴파일러가 하드웨어 특정 최적화를 수행함
    • 동일한 JAX 코드를 Google TPU와 AMD GPU에서 변경 없이 실행 가능함
    • 회사 전략은 JAX로 모델을 포팅하고, XLA 커널을 활용해 비-NVIDIA 백엔드에서 최대 성능을 추출하는 것임
    • Llama 3.1을 PyTorch에서 JAX로 처음 포팅했으며, 이제 동일한 JAX 모델이 TPU와 AMD GPU에서 잘 작동함
    • 비전과 저장소에 대한 의견을 듣고 싶어함
  • 메모리 제약을 극복하고 JIT 컴파일된 버전을 실행하는 방법 탐구 제안

    • 추가적인 성능 향상을 가져올 수 있을 것임
  • AMD GPU와 ROCm 지원에 대한 경험 공유

    • 1년 전 AMD GPU와 ROCm 지원을 시도했으나, AMD가 NVIDIA를 따라잡기에는 아직 멀었다고 느낌
    • JAX를 선택한 것은 흥미로운 접근법이지만, PyTorch에서 벗어나는 데 어떤 어려움이 있었는지 궁금함
  • 405B 모델의 추론 측면에서 실험한 경험 공유

    • 'torch.cuda'가 그렇게 나쁘지 않다고 생각함
    • AMD 버전의 PyTorch가 이를 번역해주기 때문에 이름 문제일 뿐이라고 판단함
    • rocm:pytorch 컨테이너를 사용하는 것이 rocm:jax 컨테이너를 사용하는 것만큼 쉬움
    • 성능 데이터가 많이 게시되지 않았음을 지적함
    • MFU(모델 활용률) 수치를 궁금해함
  • 성능 데이터의 부재에 대한 질문

    • AMD GPU의 대량 주문으로 인해 가치를 추출할 가능성에 대한 의문 제기
    • "아니오"라는 인상을 받음
  • Obsidian(노트 테이킹 앱)이 왜 이 일을 하는지에 대한 의문

    • 처음에는 Obsidian의 게시물인 줄 알았음
    • GitHub.com과 GitHub.io를 아직 구분하지 않은 이유에 대한 의문
  • @dang에게 URL에 사용자 이름 포함 요청

    • 이 게시물은 Obsidian 자체가 아닌 사용자 생성 블로그에 관한 것임