6P by neo 14일전 | ★ favorite | 댓글 1개
  • 딥러닝 성능을 대규모로 최적화하는 것은 ‘연금술’처럼 보이지만, 실제로는 이해 가능한 단순한 원칙으로 모델 효율을 높일 수 있음
  • 단일 가속기부터 수만 개의 가속기까지 비교적 간단한 원칙이 모든 곳에 적용되며, 이를 이해함으로써 다음과 같은 유용한 작업 수행이 가능함:
    • 모델의 각 부분이 이론적 최적값에 얼마나 근접했는지 대략적으로 파악
    • 다양한 스케일에서 여러 병렬화 기법을 선택하는 근거를 마련할 수 있음
    • 대형 Transformer 모델의 학습 및 실행에 필요한 비용과 시간 추정
    • 특정 하드웨어의 특성을 활용하는 알고리즘 설계
    • 현재 알고리즘 성능의 한계를 명확히 이해하여 하드웨어 설계
  • 필요한 배경 지식
    • LLM과 Transformer 아키텍처에 대한 기본 개념 이해 필요
    • 대규모 운영 방식에 대한 이해는 필수가 아님
    • LLM 훈련 기본 지식과 JAX 사용 경험이 있다면 더 좋음
    • Transformer 아키텍처에 대한 블로그 포스트와 JAX의 LLM 스케일링에 대한 슬라이드 참고 권장
  • 목표
    • 모델을 주어진 하드웨어에서 어떤 방식으로 병렬화하면 좋을지 추정할 수 있는 역량을 기르는 것
    • 훈련과 추론에 걸리는 시간과 비용을 대략적으로 계산할 수 있는 능력을 기르는 것

왜 관심을 가져야 하는가

  • 3~4년 전만 해도 ML 연구자 대부분은 이런 대규모 스케일 최적화에 대해 깊이 알 필요가 없었음
    • 현재는 “작은” 모델조차도 하드웨어 한계에 근접해서 동작하기 때문에, 효율적인 대규모 작업 방식 이해가 필수적이 됨
    • ML 역사는 시스템 혁신과 소프트웨어 개선이 교차 발전해 온 흐름으로 볼 수 있음
    • 최근 Transformer 모델들이 하드웨어 한계까지 사용함에 따라, 모델 효율성을 이해하지 못하면 새로운 아키텍처나 연구가 실제 적용에서 실패할 가능성이 높음
    • 벤치마크에서 20% 성능 향상을 얻어도, 하드웨어 효율이 20% 떨어지면 결국 실용성이 낮아짐
  • 모델 스케일링의 핵심 목표는 칩(가속기)의 수를 늘릴 때 선형적으로 처리량이 증가하도록 만드는 것임
    • 이를 "강한 스케일링"이라고 함
    • 칩 추가는 계산 시간을 줄이지만 칩 간 통신 비용이 발생
    • 통신이 계산보다 오래 걸리면 "통신 제한(Communication Bound)" 상태가 되어 강한 스케일링 불가능
    • 하드웨어를 충분히 이해하여 이러한 병목 현상이 발생할 위치를 예측할 수 있다면, 이를 방지하도록 모델을 설계하거나 재구성할 수 있음
  • 이 책의 목표는 TPU(및 GPU) 하드웨어의 작동 방식과 Transformer 아키텍처가 현재 하드웨어에서 잘 작동하도록 어떻게 발전해왔는지를 설명하는 것
    • 새로운 아키텍처를 설계하는 연구자와 현 세대의 LLM을 빠르게 실행하기 위해 노력하는 엔지니어 모두에게 도움이 되기를 바람

전체 개요

  • 이 글은 다음과 같이 구성됨
  • 섹션 1에서는 roofline 분석을 통해 모델의 성능 한계를 결정하는 요소(통신, 연산, 메모리)를 설명함
  • 섹션 2, 섹션 3에서는 TPU와 GPU의 내부 구조 및 칩 간 연결 방식을 다룸
    • 이를 통해 아래와 같은 질문에 답변함
      • 특정 크기의 매트릭스 곱셈은 이론적으로 얼마나 빨리 수행될 수 있는가
      • 어느 지점에서 연산이 메모리 대역폭이나 통신 대역폭에 묶이게 되는가
      • TPU 클러스터는 어떤 구조로 연결되고, 한 칩에서 다른 칩으로 데이터를 옮길 때 걸리는 대략적인 시간은 얼마인가
      • 분산된 매트릭스를 어떻게 효율적으로 곱셈할 수 있는가
  • 섹션 4에서는 Transformer 아키텍처의 수식(매트릭스 크기, 파라미터 수, FLOPs)을 자세히 다룸
  • 섹션 5섹션 7이 핵심으로, 여러 칩에 모델을 병렬화하는 다양한 방법을 소개함
    • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
    • ZeRO, Rematerialisation, Host offload, Gradient accumulation 등 메모리 절감 기법도 다룸
  • 섹션 6, 섹션 8는 LLaMA-3 모델을 TPU에서 훈련하고 추론하는 과정을 예시로 들어 실제 비용과 시간, 구성 방식을 제시함
  • 마지막으로 섹션 9, 섹션 10는 JAX에서 모델을 프로파일하고, 디버그하며, 병렬 처리를 적용하는 실제 방법을 다룸

자세한 내용 :책의 주요 섹션을 요약

  • 파트 1: Preliminaries

  • 파트 2: Transformers

    • 섹션 4: 필요한 Transformer 수식 정리

      • Transformer에서 매트릭스 곱셈이 구체적으로 어떤 형태인지
      • 파라미터 수, FLOPs, KV 캐시 크기 등을 계산하는 방법
      • Attention 연산이 Feed-Forward 블록 대비 얼마나 많은 연산을 요구하는지 파악
    • 섹션 5: Transformer 훈련 병렬화 전략

      • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel 기법 소개
      • ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload 등 메모리 절감 방안
      • 특정 모델 크기와 칩 수에 맞춰 병렬화를 구성하는 개념 정립
    • 섹션 6: LLaMA 3 TPU 훈련 적용

      • 실제 TPU 환경에서 LLaMA 3 모델을 훈련한다고 가정할 때, 소요 시간과 비용 추정
      • 배치 사이즈, 병렬화 방식, 메모리 사용량 등에 대한 구체적인 예시 제시
    • 섹션 7: Transformer 추론에 대한 모든 것

      • 추론 시에는 지연(latency)이 중요한 신규 요인으로 등장
      • KV 캐시 등으로 인한 메모리 사용과 통신 문제
      • 모델 서빙을 위해 여러 칩을 어떻게 배분하고 연결할 것인지에 대한 논의
    • 섹션 8: LLaMA 3 TPU 서빙 적용

      • TPU v5e에서 LLaMA 3를 서빙한다고 가정할 때, 대략적인 비용과 지연, 처리량 트레이드오프 분석
  • 파트 3: Practical Tutorials

Hacker News 의견
  • JAX가 앞으로 몇 년 동안 pytorch/cuda를 대체할 것이라는 기대가 있음. Deepseek 팀과의 PTX 문제는 하드웨어 성능을 최대한 활용하기 위해 더 낮은 수준의 접근 방식에 투자하는 것의 가치를 보여줌
    • Google 내부에서 성능 작업의 지침서로 사용되었음. 공개된 것이 놀랍지만, Gemini 관련 세부 사항은 제거된 것으로 보임
    • 이 가이드는 JAX/XLA 덕분에 GPU로 직접 전환할 수 있는 점이 좋음
    • JAX가 왜 AST 대신 트레이싱을 사용하는지 궁금해하는 의견이 있음
    • 작성자의 트윗 스레드 링크가 공유됨
    • Jekyll 사이트를 PDF로 변환할 방법을 찾고 있는 사람 있음
    • 훌륭한 글이라는 칭찬과 감사의 표현이 있음
    • 멋진 애니메이션을 어떻게 만드는지 궁금해하는 의견이 있음