20P by neo 6달전 | favorite | 댓글 1개
  • Llama 3 모델의 실제 동작 가능한 구현을 통해 정확한 구조 이해 하기

개요

  • Meta에서 공개한 Llama 3 모델이 주목받고 있음.
  • 24K GPUs, 15T 훈련 데이터, 10M 명령 데이터, 1.3M GPU 시간 등 압도적인 스케일과 성능을 자랑함.
  • 모델 구조는 크게 변하지 않았음. Llama 3는 GQA를 사용하지만, 이는 Llama 2 70B에서도 이미 구현된 바 있음.
  • NumPy만을 사용하여 모델 구조를 직관적으로 이해할 수 있도록 구현함.
  • Andrej Karpathy가 Llama 2 구조로 훈련한 stories15M 모델을 NumPy 압축 형식으로 변환하여 사용함.

구조

  • Llama 3 모델 구조는 42dot LLM과 동일함.
  • 모델 매개변수:
    • dim: 288
    • n_layers: 6
    • n_heads: 6
    • vocab_size: 32000
    • max_seq_len: 256
    • max_new_tokens: 50

RoPE #1

  • RoPE 임베딩을 위해 cos와 sin을 미리 계산함.
  • 이 값들은 QK에 사용됨.
  • 계산 결과는 np.outer로 곱해지고, cos와 sin이 계산됨.

RMSNorm

  • RMSNorm은 전통적인 Mini Batch나 Layer 통계 대신 활성화 값을 Root Mean Square로 정규화함.
  • 일관된 활성화 스케일링을 제공함.

QKV

  • QKV 계산은 GPT에서 하나의 가중치를 matmul한 후 분할하는 방식과 다르게, Llama는 QKV 각각에 대한 가중치를 가짐.
  • Multi-Head Attention을 위해 각 값을 재구성함.

RoPE #2

  • RoPE는 절대적 및 상대적 위치 인코딩 특성을 모두 가짐.
  • Q와 K에만 적용되며, 입력을 나누고 cos와 sin으로 곱한 후 결과를 더하고 빼서 다시 재구성함.

KV 캐시

  • GPT 스타일 생성 모델은 Masked Attention을 사용하여 KV 캐시가 가능함.
  • 이전 결과는 항상 동일하므로, K와 V를 캐시하고 Q는 마지막 값만 계산함.

GQA(Grouped-Query Attention)

  • GQA는 Llama 2에서 도입된 기술로, 메모리 절약과 성능 향상을 제공함.
  • Llama 3에서는 8B 이상의 모든 모델에 GQA가 적용됨.

Scaled Dot-Product Attention

  • Multi-Head Attention으로 각각의 Attention을 계산함.
  • 결과는 softmax와 matmul로 얻어짐.

Feed Forward

  • Llama 모델의 Feed Forward는 3개의 선형 계층을 사용하며, bias가 없음.
  • swish 값을 생성하고, x_V와 곱한 후 다시 다운스케일링함.

SwiGLU

  • SwiGLU는 여러 피드 포워드 계층의 독특한 조합으로 모델 성능을 향상시킴.

Linear

  • 최종 출력은 마지막 logit만 matmul로 계산하여 속도를 높임.

생성

  • 추출된 logit을 사용하여 토큰을 하나씩 생성함.
  • Prefill Phase와 Decode Phase로 나뉨.
  • Prefill Phase에서는 모든 입력을 전달하고, Decode Phase에서는 마지막 토큰 ID만 전달하여 결과를 얻음.

예제

  • 다음과 같이 실행할 수 있음:
    $ python llama3.py "I have a dream"  
    

GitHub

참고 문헌

  1. Exploring and Building the Llama 3 Architecture
  2. Rotation Matrix
  3. Mastering LLM Techniques: Inference Optimization
  4. arXiv:2305.13245

GN⁺의 의견

  • Llama 3 모델의 구조와 성능: Llama 3 모델은 기존 Llama 2 모델의 구조를 유지하면서도 성능을 크게 향상시킴. 이는 모델의 확장성과 효율성을 동시에 고려한 결과임.
  • NumPy로 구현한 이유: NumPy를 사용하여 모델을 구현함으로써, 모델의 구조와 동작을 더 직관적으로 이해할 수 있음. 이는 학습자나 연구자에게 큰 도움이 됨.
  • GQA의 도입: GQA는 메모리 절약과 성능 향상을 동시에 제공하는 기술로, Llama 3에서 모든 모델에 적용됨으로써 모델의 효율성을 극대화함.
  • KV 캐시의 중요성: KV 캐시는 GPT 스타일 생성 모델에서 중요한 역할을 하며, 이를 통해 모델의 계산 효율성을 크게 높일 수 있음.
  • 실제 사용 사례: 예제 코드를 통해 모델을 실제로 실행해볼 수 있으며, 이는 모델의 성능을 직접 확인할 수 있는 좋은 기회임.

해커뉴스에 올라온 것은 영문인데, 원저자인 Likejazz 님이 한국어로 작성해두신 링크로 변경했습니다.