순수 NumPy로 구현하는 라마 3 모델
(docs.likejazz.com)- Llama 3 모델의 실제 동작 가능한 구현을 통해 정확한 구조 이해 하기
 
개요
- Meta에서 공개한 Llama 3 모델이 주목받고 있음.
 - 24K GPUs, 15T 훈련 데이터, 10M 명령 데이터, 1.3M GPU 시간 등 압도적인 스케일과 성능을 자랑함.
 - 모델 구조는 크게 변하지 않았음. Llama 3는 GQA를 사용하지만, 이는 Llama 2 70B에서도 이미 구현된 바 있음.
 - NumPy만을 사용하여 모델 구조를 직관적으로 이해할 수 있도록 구현함.
 - Andrej Karpathy가 Llama 2 구조로 훈련한 stories15M 모델을 NumPy 압축 형식으로 변환하여 사용함.
 
구조
- Llama 3 모델 구조는 42dot LLM과 동일함.
 - 모델 매개변수:
- 
dim: 288 - 
n_layers: 6 - 
n_heads: 6 - 
vocab_size: 32000 - 
max_seq_len: 256 - 
max_new_tokens: 50 
 - 
 
RoPE #1
- RoPE 임베딩을 위해 cos와 sin을 미리 계산함.
 - 이 값들은 
Q와K에 사용됨. - 계산 결과는 
np.outer로 곱해지고, cos와 sin이 계산됨. 
RMSNorm
- RMSNorm은 전통적인 Mini Batch나 Layer 통계 대신 활성화 값을 Root Mean Square로 정규화함.
 - 일관된 활성화 스케일링을 제공함.
 
QKV
- QKV 계산은 GPT에서 하나의 가중치를 matmul한 후 분할하는 방식과 다르게, Llama는 QKV 각각에 대한 가중치를 가짐.
 - Multi-Head Attention을 위해 각 값을 재구성함.
 
RoPE #2
- RoPE는 절대적 및 상대적 위치 인코딩 특성을 모두 가짐.
 - Q와 K에만 적용되며, 입력을 나누고 cos와 sin으로 곱한 후 결과를 더하고 빼서 다시 재구성함.
 
KV 캐시
- GPT 스타일 생성 모델은 Masked Attention을 사용하여 KV 캐시가 가능함.
 - 이전 결과는 항상 동일하므로, K와 V를 캐시하고 Q는 마지막 값만 계산함.
 
GQA(Grouped-Query Attention)
- GQA는 Llama 2에서 도입된 기술로, 메모리 절약과 성능 향상을 제공함.
 - Llama 3에서는 8B 이상의 모든 모델에 GQA가 적용됨.
 
Scaled Dot-Product Attention
- Multi-Head Attention으로 각각의 Attention을 계산함.
 - 결과는 softmax와 matmul로 얻어짐.
 
Feed Forward
- Llama 모델의 Feed Forward는 3개의 선형 계층을 사용하며, bias가 없음.
 - swish 값을 생성하고, 
x_V와 곱한 후 다시 다운스케일링함. 
SwiGLU
- SwiGLU는 여러 피드 포워드 계층의 독특한 조합으로 모델 성능을 향상시킴.
 
Linear
- 최종 출력은 마지막 logit만 matmul로 계산하여 속도를 높임.
 
생성
- 추출된 logit을 사용하여 토큰을 하나씩 생성함.
 - Prefill Phase와 Decode Phase로 나뉨.
 - Prefill Phase에서는 모든 입력을 전달하고, Decode Phase에서는 마지막 토큰 ID만 전달하여 결과를 얻음.
 
예제
- 다음과 같이 실행할 수 있음:
$ python llama3.py "I have a dream" 
GitHub
- 전체 소스 코드는 likejazz/llama3.np에서 확인 가능함.
 
참고 문헌
- Exploring and Building the Llama 3 Architecture
 - Rotation Matrix
 - Mastering LLM Techniques: Inference Optimization
 - arXiv:2305.13245
 
GN⁺의 의견
- Llama 3 모델의 구조와 성능: Llama 3 모델은 기존 Llama 2 모델의 구조를 유지하면서도 성능을 크게 향상시킴. 이는 모델의 확장성과 효율성을 동시에 고려한 결과임.
 - NumPy로 구현한 이유: NumPy를 사용하여 모델을 구현함으로써, 모델의 구조와 동작을 더 직관적으로 이해할 수 있음. 이는 학습자나 연구자에게 큰 도움이 됨.
 - GQA의 도입: GQA는 메모리 절약과 성능 향상을 동시에 제공하는 기술로, Llama 3에서 모든 모델에 적용됨으로써 모델의 효율성을 극대화함.
 - KV 캐시의 중요성: KV 캐시는 GPT 스타일 생성 모델에서 중요한 역할을 하며, 이를 통해 모델의 계산 효율성을 크게 높일 수 있음.
 - 실제 사용 사례: 예제 코드를 통해 모델을 실제로 실행해볼 수 있으며, 이는 모델의 성능을 직접 확인할 수 있는 좋은 기회임.