1P by neo 5달전 | favorite | 댓글 1개

FlashAttention-3: 비동기 및 저정밀도로 빠르고 정확한 Attention

  • Attention의 중요성

    • Attention은 Transformer 구조의 핵심 계층으로, 대형 언어 모델과 긴 문맥 응용 프로그램에서 병목 현상을 일으킴.
    • FlashAttention과 FlashAttention-2는 GPU에서 메모리 읽기/쓰기를 최소화하여 Attention을 가속화하는 접근 방식을 개척함.
    • 이로 인해 LLM의 문맥 길이가 크게 증가함.
  • FlashAttention-3의 주요 기술

    • 비동기성 활용: Tensor Cores와 TMA의 비동기성을 활용하여 전체 계산과 데이터 이동을 겹침.
    • 블록 단위 연산: 블록 단위의 행렬 곱셈과 softmax 연산을 교차 수행.
    • 저정밀도 처리: FP8 저정밀도 지원을 활용하여 성능을 향상시킴.
  • FlashAttention-3의 성능 향상

    • GPU 활용 효율성: H100 GPU의 최대 성능을 75%까지 활용하여 이전 버전보다 1.5-2배 빠름.
    • 저정밀도 성능: FP8을 사용하여 처리 속도를 높이고 메모리 사용량을 줄임.
    • 긴 문맥 처리: Attention 메커니즘을 가속화하여 더 긴 텍스트를 효율적으로 처리 가능.
  • FlashAttention 요약

    • FlashAttention은 Attention 계산을 재정렬하고 타일링과 재계산을 활용하여 속도를 크게 높이고 메모리 사용량을 줄임.
    • 타일링을 통해 입력 블록을 로드하고, 해당 블록에 대해 Attention을 수행한 후 출력을 업데이트함.
    • 중간 Attention 행렬을 메모리에 쓰지 않음으로써 메모리 읽기/쓰기 양을 줄임.
  • Hopper GPU의 새로운 하드웨어 기능

    • WGMMA: 새로운 Tensor Cores를 활용하여 높은 처리량을 제공.
    • TMA: 글로벌 메모리와 공유 메모리 간 데이터 전송을 가속화하는 하드웨어 유닛.
    • FP8 저정밀도: FP8을 사용하여 Tensor Core 처리량을 두 배로 늘림.
  • 비동기성: GEMM과 Softmax 겹치기

    • 겹치기의 필요성: GEMM과 softmax를 병렬로 수행하여 성능을 극대화함.
    • 핑퐁 스케줄링: 두 워프 그룹이 번갈아 가며 GEMM과 softmax를 수행하여 성능을 향상시킴.
    • 워프 그룹 내 겹치기: 동일한 워프 그룹 내에서 GEMM과 softmax를 병렬로 수행하여 처리량을 증가시킴.
  • 저정밀도: 비일관 처리로 양자화 오류 감소

    • 비일관 처리: Hadamard 변환을 사용하여 양자화 오류를 줄임.
    • 실험 결과: 비일관 처리를 통해 양자화 오류를 2.6배 감소시킴.
  • Attention 벤치마크

    • FP16: FlashAttention-2보다 약 1.6-1.8배 빠름.
    • FP8: 최대 1.2 PFLOPS에 도달.

GN⁺의 정리

  • FlashAttention-3는 GPU의 새로운 하드웨어 기능을 활용하여 Attention 메커니즘의 성능을 크게 향상시킴.
  • 긴 문맥을 효율적으로 처리할 수 있어 대형 언어 모델의 성능을 극대화함.
  • PyTorch와 같은 주요 프레임워크에 통합될 가능성이 높아 향후 AI 연구와 응용에 큰 영향을 미칠 것임.
  • 유사한 기능을 제공하는 프로젝트로는 Triton과 cuDNN이 있음.
Hacker News 의견
  • Tri Dao가 FA3 작업을 2022년 4월부터 시작한 것으로 보임

    • Hopper/H100 발표 후 2년이 지나서야 코드가 공개된 이유는 더 나은 솔루션이 준비되었기 때문일 가능성이 있음
    • 최근 Tri의 연구는 SSM과 Mamba 스타일 아키텍처에 집중되어 있음
    • Flash Attention은 시퀀스 길이에 대해 이차 시간 복잡성을 가지지만, 최신 알고리즘은 이차 이하의 복잡성을 가짐
    • Dao와 Gu는 올해 Mamba/SSM이 Transformer와 같은 하드웨어 가속을 받을 수 있도록 공식화하는 논문을 발표함
  • Flash Attention 알고리즘이 하드웨어에 얼마나 의존적인지 궁금함

    • H100 GPU의 비동기 기능을 활용한다고 언급됨
    • Flash Attention 라이브러리는 CUDA를 필요로 하지만, Metal로 포팅된 것으로 보임
    • 알고리즘이 순수 함수라면 어떤 GPU/ML 프레임워크에서도 구현 가능할 것이라고 상상함
  • 컴파일러가 FlashAttention과 같은 최적화를 스스로 찾을 수 있을지 궁금함

    • TVM과 tinygrad가 그 방향으로 작업 중이지만, 실현 가능성에 대해 의문을 가짐
  • ROCm/AMD MI300x로 포팅을 원하는 사람은 연락을 달라고 함

    • 컴퓨팅 시간을 기부할 의향이 있음
  • TMA (Tensor Memory Accelerator)는 글로벌 메모리와 공유 메모리 간의 데이터 전송을 가속화하는 하드웨어 유닛임

    • 레지스터를 해방시켜 타일 크기와 효율성을 증가시킴
  • FlashAttention-3는 Hopper GPU (예: H100)에 최적화되어 있음

    • 소비자용 GPU (예: 3090, 4090)에서는 어떻게 작동하는지 궁금함
  • 현대 LLM에서 sigmoid와 같은 활성화 함수가 매우 느리다고 언급됨

    • SiLU, Swish, SOLU와 같은 활성화 함수가 많이 사용됨
    • Relu가 성능 저하를 덜 일으킨다면, Relu로 돌아가는 것이 더 나을 수도 있음
  • 가변 마스킹이 없는 경우보다 있는 경우 Flash Attention이 5배 느린 이유가 궁금함

    • 좋은 마스킹 지원의 부족이 최적화를 거의 무효화함
  • FlashAttention이 LLM의 attention 연산을 대체할 수 있는지 궁금함

    • LLM이 FA를 사용하도록 특별히 훈련되어야 하는지 궁금함
    • FA가 GQA (grouped query attention)나 슬라이딩 윈도우 attention과 같은 전략과 어떻게 관련되는지 궁금함
    • llama.cpp가 Flash Attention 지원을 추가했을 때, 단순히 Flash Attention 제공 CUDA 커널을 소비하는 것인지 궁금함
    • FlashAttention과 Triton을 비교하는 것이 무엇을 의미하는지 이해하기 어려움
  • 고가의 하드웨어가 필요함