GN⁺: 당신의 모델을 스케일 하는 법: TPU에서의 LLM에 대한 시스템적 관점
(jax-ml.github.io)- 딥러닝 성능을 대규모로 최적화하는 것은 ‘연금술’처럼 보이지만, 실제로는 이해 가능한 단순한 원칙으로 모델 효율을 높일 수 있음
- 단일 가속기부터 수만 개의 가속기까지 비교적 간단한 원칙이 모든 곳에 적용되며, 이를 이해함으로써 다음과 같은 유용한 작업 수행이 가능함:
- 모델의 각 부분이 이론적 최적값에 얼마나 근접했는지 대략적으로 파악
- 다양한 스케일에서 여러 병렬화 기법을 선택하는 근거를 마련할 수 있음
- 대형 Transformer 모델의 학습 및 실행에 필요한 비용과 시간 추정
- 특정 하드웨어의 특성을 활용하는 알고리즘 설계
- 현재 알고리즘 성능의 한계를 명확히 이해하여 하드웨어 설계
- 필요한 배경 지식
- LLM과 Transformer 아키텍처에 대한 기본 개념 이해 필요
- 대규모 운영 방식에 대한 이해는 필수가 아님
- LLM 훈련 기본 지식과 JAX 사용 경험이 있다면 더 좋음
- Transformer 아키텍처에 대한 블로그 포스트와 JAX의 LLM 스케일링에 대한 슬라이드 참고 권장
- 목표
- 모델을 주어진 하드웨어에서 어떤 방식으로 병렬화하면 좋을지 추정할 수 있는 역량을 기르는 것
- 훈련과 추론에 걸리는 시간과 비용을 대략적으로 계산할 수 있는 능력을 기르는 것
왜 관심을 가져야 하는가
- 3~4년 전만 해도 ML 연구자 대부분은 이런 대규모 스케일 최적화에 대해 깊이 알 필요가 없었음
- 현재는 “작은” 모델조차도 하드웨어 한계에 근접해서 동작하기 때문에, 효율적인 대규모 작업 방식 이해가 필수적이 됨
- ML 역사는 시스템 혁신과 소프트웨어 개선이 교차 발전해 온 흐름으로 볼 수 있음
- 최근 Transformer 모델들이 하드웨어 한계까지 사용함에 따라, 모델 효율성을 이해하지 못하면 새로운 아키텍처나 연구가 실제 적용에서 실패할 가능성이 높음
- 벤치마크에서 20% 성능 향상을 얻어도, 하드웨어 효율이 20% 떨어지면 결국 실용성이 낮아짐
- 모델 스케일링의 핵심 목표는 칩(가속기)의 수를 늘릴 때 선형적으로 처리량이 증가하도록 만드는 것임
- 이를 "강한 스케일링"이라고 함
- 칩 추가는 계산 시간을 줄이지만 칩 간 통신 비용이 발생
- 통신이 계산보다 오래 걸리면 "통신 제한(Communication Bound)" 상태가 되어 강한 스케일링 불가능
- 하드웨어를 충분히 이해하여 이러한 병목 현상이 발생할 위치를 예측할 수 있다면, 이를 방지하도록 모델을 설계하거나 재구성할 수 있음
- 이 책의 목표는 TPU(및 GPU) 하드웨어의 작동 방식과 Transformer 아키텍처가 현재 하드웨어에서 잘 작동하도록 어떻게 발전해왔는지를 설명하는 것
- 새로운 아키텍처를 설계하는 연구자와 현 세대의 LLM을 빠르게 실행하기 위해 노력하는 엔지니어 모두에게 도움이 되기를 바람
전체 개요
- 이 글은 다음과 같이 구성됨
- 섹션 1에서는 roofline 분석을 통해 모델의 성능 한계를 결정하는 요소(통신, 연산, 메모리)를 설명함
-
섹션 2, 섹션 3에서는 TPU와 GPU의 내부 구조 및 칩 간 연결 방식을 다룸
- 이를 통해 아래와 같은 질문에 답변함
- 특정 크기의 매트릭스 곱셈은 이론적으로 얼마나 빨리 수행될 수 있는가
- 어느 지점에서 연산이 메모리 대역폭이나 통신 대역폭에 묶이게 되는가
- TPU 클러스터는 어떤 구조로 연결되고, 한 칩에서 다른 칩으로 데이터를 옮길 때 걸리는 대략적인 시간은 얼마인가
- 분산된 매트릭스를 어떻게 효율적으로 곱셈할 수 있는가
- 이를 통해 아래와 같은 질문에 답변함
- 섹션 4에서는 Transformer 아키텍처의 수식(매트릭스 크기, 파라미터 수, FLOPs)을 자세히 다룸
-
섹션 5와 섹션 7이 핵심으로, 여러 칩에 모델을 병렬화하는 다양한 방법을 소개함
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- ZeRO, Rematerialisation, Host offload, Gradient accumulation 등 메모리 절감 기법도 다룸
- 섹션 6, 섹션 8는 LLaMA-3 모델을 TPU에서 훈련하고 추론하는 과정을 예시로 들어 실제 비용과 시간, 구성 방식을 제시함
- 마지막으로 섹션 9, 섹션 10는 JAX에서 모델을 프로파일하고, 디버그하며, 병렬 처리를 적용하는 실제 방법을 다룸
자세한 내용 :책의 주요 섹션을 요약
-
파트 1: Preliminaries
-
- 알고리즘을 제약하는 세 가지 요소: 연산, 통신, 메모리
- 이로부터 연산 속도의 상한선을 추정하는 방법을 배움
-
- TPU가 어떤 식으로 연산하는지
- Systolic array 구조가 무엇인지
- TPU가 메모리와 통신 대역폭을 어떻게 제공하는지에 대한 기본적인 이해
-
- 모델 파라미터를 여러 칩에 나누어 저장(Sharding)하는 기법
- 분산된 매트릭스 연산 시 발생하는 통신과 병목을 다루는 방식
-
-
파트 2: Transformers
-
- Transformer에서 매트릭스 곱셈이 구체적으로 어떤 형태인지
- 파라미터 수, FLOPs, KV 캐시 크기 등을 계산하는 방법
- Attention 연산이 Feed-Forward 블록 대비 얼마나 많은 연산을 요구하는지 파악
-
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel 기법 소개
- ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload 등 메모리 절감 방안
- 특정 모델 크기와 칩 수에 맞춰 병렬화를 구성하는 개념 정립
-
- 실제 TPU 환경에서 LLaMA 3 모델을 훈련한다고 가정할 때, 소요 시간과 비용 추정
- 배치 사이즈, 병렬화 방식, 메모리 사용량 등에 대한 구체적인 예시 제시
-
- 추론 시에는 지연(latency)이 중요한 신규 요인으로 등장
- KV 캐시 등으로 인한 메모리 사용과 통신 문제
- 모델 서빙을 위해 여러 칩을 어떻게 배분하고 연결할 것인지에 대한 논의
-
- TPU v5e에서 LLaMA 3를 서빙한다고 가정할 때, 대략적인 비용과 지연, 처리량 트레이드오프 분석
-
-
파트 3: Practical Tutorials
-
- JAX+XLA 스택 이해
- 실제 성능 저하 이슈 파악과 해결책
- JAX/TensorBoard 프로파일러 사용법
-
- JAX의 병렬화 API(primitives) 활용법
- 예제와 문제를 통해 병렬 연산 개념을 익힘
-
- TPU와 LLM에 대한 추가 읽을거리
- 전체 내용을 간략히 마무리하며, 미래 전망 언급
-
Hacker News 의견
- JAX가 앞으로 몇 년 동안 pytorch/cuda를 대체할 것이라는 기대가 있음. Deepseek 팀과의 PTX 문제는 하드웨어 성능을 최대한 활용하기 위해 더 낮은 수준의 접근 방식에 투자하는 것의 가치를 보여줌
- Google 내부에서 성능 작업의 지침서로 사용되었음. 공개된 것이 놀랍지만, Gemini 관련 세부 사항은 제거된 것으로 보임
- 이 가이드는 JAX/XLA 덕분에 GPU로 직접 전환할 수 있는 점이 좋음
- JAX가 왜 AST 대신 트레이싱을 사용하는지 궁금해하는 의견이 있음
- 작성자의 트윗 스레드 링크가 공유됨
- Jekyll 사이트를 PDF로 변환할 방법을 찾고 있는 사람 있음
- 훌륭한 글이라는 칭찬과 감사의 표현이 있음
- 멋진 애니메이션을 어떻게 만드는지 궁금해하는 의견이 있음