# 당신의 모델을 스케일 하는 법: TPU에서의 LLM에 대한 시스템적 관점

> Clean Markdown view of GeekNews topic #19097. Use the original source for factual precision when an external source URL is present.

## Metadata

- GeekNews HTML: [https://news.hada.io/topic?id=19097](https://news.hada.io/topic?id=19097)
- GeekNews Markdown: [https://news.hada.io/topic/19097.md](https://news.hada.io/topic/19097.md)
- Type: GN+
- Author: [neo](https://news.hada.io/@neo)
- Published: 2025-02-07T06:12:34+09:00
- Updated: 2025-02-07T06:12:34+09:00
- Original source: [jax-ml.github.io](https://jax-ml.github.io/scaling-book/)
- Points: 7
- Comments: 1

## Summary

딥러닝 모델의 대규모 최적화는 단순한 원칙을 통해 모델 효율을 높일 수 있으며, 이는 단일 가속기부터 수만 개의 가속기까지 적용 가능합니다. 이 책은 TPU와 GPU 하드웨어의 작동 방식과 Transformer 아키텍처가 하드웨어에서 어떻게 최적화되어 왔는지를 설명하며, 모델의 병렬화와 효율적인 실행을 위한 다양한 기법을 소개합니다. 또한, LLaMA-3 모델을 TPU에서 훈련하고 추론하는 과정을 통해 실제 비용과 시간, 구성 방식을 제시하며, JAX를 활용한 프로파일링 및 병렬 처리 방법도 다룹니다.

## Topic Body

- 딥러닝 성능을 대규모로 최적화하는 것은 ‘연금술’처럼 보이지만, 실제로는 이해 가능한 단순한 원칙으로 모델 효율을 높일 수 있음  
- 단일 가속기부터 수만 개의 가속기까지 비교적 간단한 원칙이 모든 곳에 적용되며, 이를 이해함으로써 다음과 같은 유용한 작업 수행이 가능함:  
  - 모델의 각 부분이 이론적 최적값에 얼마나 근접했는지 대략적으로 파악  
  - 다양한 스케일에서 여러 병렬화 기법을 선택하는 근거를 마련할 수 있음  
  - 대형 Transformer 모델의 학습 및 실행에 필요한 비용과 시간 추정  
  - 특정 하드웨어의 특성을 활용하는 알고리즘 설계  
  - 현재 알고리즘 성능의 한계를 명확히 이해하여 하드웨어 설계  
- 필요한 배경 지식  
  - LLM과 Transformer 아키텍처에 대한 기본 개념 이해 필요  
  - 대규모 운영 방식에 대한 이해는 필수가 아님  
  - LLM 훈련 기본 지식과 JAX 사용 경험이 있다면 더 좋음  
  - Transformer 아키텍처에 대한 블로그 포스트와 JAX의 LLM 스케일링에 대한 슬라이드 참고 권장  
- 목표  
  - 모델을 주어진 하드웨어에서 어떤 방식으로 병렬화하면 좋을지 추정할 수 있는 역량을 기르는 것  
  - 훈련과 추론에 걸리는 시간과 비용을 대략적으로 계산할 수 있는 능력을 기르는 것  
  
### 왜 관심을 가져야 하는가  
- 3~4년 전만 해도 ML 연구자 대부분은 이런 대규모 스케일 최적화에 대해 깊이 알 필요가 없었음  
  - 현재는 “작은” 모델조차도 하드웨어 한계에 근접해서 동작하기 때문에, 효율적인 대규모 작업 방식 이해가 필수적이 됨  
  - ML 역사는 시스템 혁신과 소프트웨어 개선이 교차 발전해 온 흐름으로 볼 수 있음  
  - 최근 Transformer 모델들이 하드웨어 한계까지 사용함에 따라, 모델 효율성을 이해하지 못하면 새로운 아키텍처나 연구가 실제 적용에서 실패할 가능성이 높음  
  - 벤치마크에서 20% 성능 향상을 얻어도, 하드웨어 효율이 20% 떨어지면 결국 실용성이 낮아짐  
- 모델 스케일링의 핵심 목표는 칩(가속기)의 수를 늘릴 때 선형적으로 처리량이 증가하도록 만드는 것임  
  - 이를 "강한 스케일링"이라고 함  
  - 칩 추가는 계산 시간을 줄이지만 칩 간 통신 비용이 발생  
  - 통신이 계산보다 오래 걸리면 "통신 제한(Communication Bound)" 상태가 되어 강한 스케일링 불가능  
  - 하드웨어를 충분히 이해하여 이러한 병목 현상이 발생할 위치를 예측할 수 있다면, 이를 방지하도록 모델을 설계하거나 재구성할 수 있음   
- 이 책의 목표는 **TPU(및 GPU) 하드웨어의 작동 방식과 Transformer 아키텍처가 현재 하드웨어에서 잘 작동하도록 어떻게 발전해왔는지를 설명하는 것**  
  - 새로운 아키텍처를 설계하는 연구자와 현 세대의 LLM을 빠르게 실행하기 위해 노력하는 엔지니어 모두에게 도움이 되기를 바람   
  
### 전체 개요  
- 이 글은 다음과 같이 구성됨  
- [섹션 1](https://jax-ml.github.io/scaling-book/roofline)에서는 roofline 분석을 통해 모델의 성능 한계를 결정하는 요소(통신, 연산, 메모리)를 설명함  
- [섹션 2](https://jax-ml.github.io/scaling-book/tpus), [섹션 3](https://jax-ml.github.io/scaling-book/sharding)에서는 TPU와 GPU의 내부 구조 및 칩 간 연결 방식을 다룸  
  - 이를 통해 아래와 같은 질문에 답변함  
    - 특정 크기의 매트릭스 곱셈은 이론적으로 얼마나 빨리 수행될 수 있는가  
    - 어느 지점에서 연산이 메모리 대역폭이나 통신 대역폭에 묶이게 되는가  
    - TPU 클러스터는 어떤 구조로 연결되고, 한 칩에서 다른 칩으로 데이터를 옮길 때 걸리는 대략적인 시간은 얼마인가  
    - 분산된 매트릭스를 어떻게 효율적으로 곱셈할 수 있는가  
- [섹션 4](https://jax-ml.github.io/scaling-book/transformers)에서는 Transformer 아키텍처의 수식(매트릭스 크기, 파라미터 수, FLOPs)을 자세히 다룸  
- [섹션 5](https://jax-ml.github.io/scaling-book/training)와 [섹션 7](https://jax-ml.github.io/scaling-book/inference)이 핵심으로, 여러 칩에 모델을 병렬화하는 다양한 방법을 소개함  
  - Data parallel, Tensor parallel, Pipeline parallel, Expert parallel  
  - ZeRO, Rematerialisation, Host offload, Gradient accumulation 등 메모리 절감 기법도 다룸  
- [섹션 6](https://jax-ml.github.io/scaling-book/applied-training), [섹션 8](https://jax-ml.github.io/scaling-book/applied-inference)는 LLaMA-3 모델을 TPU에서 훈련하고 추론하는 과정을 예시로 들어 실제 비용과 시간, 구성 방식을 제시함  
- 마지막으로 [섹션 9](https://jax-ml.github.io/scaling-book/profiling), [섹션 10](https://jax-ml.github.io/scaling-book/jax-stuff)는 JAX에서 모델을 프로파일하고, 디버그하며, 병렬 처리를 적용하는 실제 방법을 다룸  
  
### 자세한 내용 :책의 주요 섹션을 요약  
- 파트 1: Preliminaries  
  - [섹션 1: 간단한 Roofline 분석 소개](https://jax-ml.github.io/scaling-book/roofline)  
    - 알고리즘을 제약하는 세 가지 요소: 연산, 통신, 메모리  
    - 이로부터 연산 속도의 상한선을 추정하는 방법을 배움  
  
  - [섹션 2: TPU를 바라보는 관점](https://jax-ml.github.io/scaling-book/tpus)  
    - TPU가 어떤 식으로 연산하는지  
    - Systolic array 구조가 무엇인지  
    - TPU가 메모리와 통신 대역폭을 어떻게 제공하는지에 대한 기본적인 이해  
  
  - [섹션 3: 분산 매트릭스와 분산 곱셈](https://jax-ml.github.io/scaling-book/sharding)  
    - 모델 파라미터를 여러 칩에 나누어 저장(Sharding)하는 기법  
    - 분산된 매트릭스 연산 시 발생하는 통신과 병목을 다루는 방식  
  
- 파트 2: Transformers  
  
  - [섹션 4: 필요한 Transformer 수식 정리](https://jax-ml.github.io/scaling-book/transformers)  
    - Transformer에서 매트릭스 곱셈이 구체적으로 어떤 형태인지  
    - 파라미터 수, FLOPs, KV 캐시 크기 등을 계산하는 방법  
    - Attention 연산이 Feed-Forward 블록 대비 얼마나 많은 연산을 요구하는지 파악  
  
  - [섹션 5: Transformer 훈련 병렬화 전략](https://jax-ml.github.io/scaling-book/training)  
    - Data parallel, Tensor parallel, Pipeline parallel, Expert parallel 기법 소개  
    - ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload 등 메모리 절감 방안  
    - 특정 모델 크기와 칩 수에 맞춰 병렬화를 구성하는 개념 정립  
  
  - [섹션 6: LLaMA 3 TPU 훈련 적용](https://jax-ml.github.io/scaling-book/applied-training)  
    - 실제 TPU 환경에서 LLaMA 3 모델을 훈련한다고 가정할 때, 소요 시간과 비용 추정  
    - 배치 사이즈, 병렬화 방식, 메모리 사용량 등에 대한 구체적인 예시 제시  
  
  - [섹션 7: Transformer 추론에 대한 모든 것](https://jax-ml.github.io/scaling-book/inference)  
    - 추론 시에는 지연(latency)이 중요한 신규 요인으로 등장  
    - KV 캐시 등으로 인한 메모리 사용과 통신 문제  
    - 모델 서빙을 위해 여러 칩을 어떻게 배분하고 연결할 것인지에 대한 논의  
  
  - [섹션 8: LLaMA 3 TPU 서빙 적용](https://jax-ml.github.io/scaling-book/applied-inference)  
    - TPU v5e에서 LLaMA 3를 서빙한다고 가정할 때, 대략적인 비용과 지연, 처리량 트레이드오프 분석  
  
- 파트 3: Practical Tutorials   
  
  - [섹션 9: TPU 코드 프로파일링 방법](https://jax-ml.github.io/scaling-book/profiling)  
    - JAX+XLA 스택 이해  
    - 실제 성능 저하 이슈 파악과 해결책  
    - JAX/TensorBoard 프로파일러 사용법  
  
  - [섹션 10: JAX로 TPU 프로그래밍하기](https://jax-ml.github.io/scaling-book/jax-stuff)  
    - JAX의 병렬화 API(primitives) 활용법  
    - 예제와 문제를 통해 병렬 연산 개념을 익힘  
  
  - [섹션 11: 결론과 추가 자료](https://jax-ml.github.io/scaling-book/conclusion)  
    - TPU와 LLM에 대한 추가 읽을거리  
    - 전체 내용을 간략히 마무리하며, 미래 전망 언급

## Comments



### Comment 34270

- Author: neo
- Created: 2025-02-07T06:12:36+09:00
- Points: 1

###### [Hacker News 의견](https://news.ycombinator.com/item?id=42936910) 
- JAX가 앞으로 몇 년 동안 pytorch/cuda를 대체할 것이라는 기대가 있음. Deepseek 팀과의 PTX 문제는 하드웨어 성능을 최대한 활용하기 위해 더 낮은 수준의 접근 방식에 투자하는 것의 가치를 보여줌
  - Google 내부에서 성능 작업의 지침서로 사용되었음. 공개된 것이 놀랍지만, Gemini 관련 세부 사항은 제거된 것으로 보임
  - 이 가이드는 JAX/XLA 덕분에 GPU로 직접 전환할 수 있는 점이 좋음
  - JAX가 왜 AST 대신 트레이싱을 사용하는지 궁금해하는 의견이 있음
  - 작성자의 트윗 스레드 링크가 공유됨
  - Jekyll 사이트를 PDF로 변환할 방법을 찾고 있는 사람 있음
  - 훌륭한 글이라는 칭찬과 감사의 표현이 있음
  - 멋진 애니메이션을 어떻게 만드는지 궁금해하는 의견이 있음
