FFT의 반격: Self-Attention의 효율적 대안
(arxiv.org)- 긴 컨텍스트 Transformer에서 self-attention 비용이 병목이 되는 상황에서, SPECTRE는 FFT 기반 토큰 믹서로 레이어당 복잡도를 O(L²)에서 O(L log L) 로 낮춤
- 각 attention head는 빠른 real FFT, 콘텐츠 적응형 spectral gate, inverse FFT 조합으로 바뀌며 기존 Transformer 구조는 유지됨
- 자동회귀 생성에서는 Prefix-FFT cache로 매 단계 FFT 재계산 부담을 줄이고, 선택형 wavelet 모듈로 로컬 특징 손실을 보완할 수 있음
- Llama-3.2-1B 백본에서 SDPA, FlashAttention-2, SPECTRE를 비교했으며, NVIDIA A100-80GB에서 512~128k 토큰 처리량과 지연시간을 측정함
- SPECTRE는 PG-19와 ImageNet-1k에서 기준 성능과 같거나 더 높았고, 6% 미만 파라미터 추가로 일반 GPU의 긴 컨텍스트 처리를 목표로 함
Self-attention의 2차 비용을 FFT로 줄이는 방식
- 긴 컨텍스트 Transformer는 multi-turn dialogue, 책 길이 요약, 고해상도 비전처럼 수만 토큰을 다루는 작업에서 필요함
- 기존 self-attention은 O(n²d) 비용 때문에 컨텍스트가 길어질수록 추론 지연시간과 메모리 사용이 커짐
- SPECTRE는 self-attention 레이어를 주파수 영역 토큰 믹서로 바꾸는 drop-in 대체 방식임
- 토큰을 orthonormal Fourier basis로 투영함
- 콘텐츠 적응형 대각 gate와 선택형 low-rank gate를 적용함
- inverse transform으로 다시 토큰 공간으로 되돌림
- 주변 네트워크 아키텍처를 바꾸지 않으면서 레이어당 복잡도를 O(n log n) 으로 낮추는 것이 핵심임
토큰 믹서 구성과 생성 지원
- SPECTRE의 attention head 대체 구성은 fast real FFT, spectral gate, inverse FFT임
- spectral gating은 n/2 + 1개의 주파수 계수에서 동작해 계산과 메모리 사용을 줄이면서 표현력을 유지하도록 설계됨
- Prefix-FFT cache는 표준 KV-cache와 비슷한 역할로 스트리밍 디코딩을 지원함
- 자동회귀 생성에서 매 time step마다 FFT를 다시 계산해야 하는 기존 spectral mixer의 약점을 줄임
- 고정 메모리 예산 안에서 효율적인 생성을 가능하게 하는 구조임
- 선택형 Wavelet Refinement Module은 순수 spectral 방식에서 손실될 수 있는 로컬 디테일을 보완하며, 계산 오버헤드는 작음
기존 Transformer에 적용하는 방법
- SPECTRE는 multi-head attention 레이어를 직접 대체할 수 있어 별도 아키텍처 개편을 요구하지 않음
- 기존 사전학습 모델은 SPECTRE 레이어로 fine-tuning할 수 있음
- 업데이트 대상은 새로 도입된 파라미터임
- 추가 파라미터는 전체 가중치의 6% 미만임
- specialized optimization이나 비표준 아키텍처가 필요한 접근과 달리, 주변 Transformer 구조를 유지함
Llama-3.2-1B 기반 실험
- 동일한 Llama-3.2-1B 백본에 세 가지 attention kernel을 적용해 비교함
- standard softmax-dot-product attention(SDPA)
- FlashAttention-2
- SPECTRE mixer
- 측정 환경은 NVIDIA A100-80GB이며, 시퀀스 길이는 L ∈ {512, 1k, 4k, 8k, 32k, 128k}임
- 지표는 tokens-per-second 처리량과 single-batch latency임
- 처리량은 높을수록 좋음
- latency는 낮을수록 좋음
- SPECTRE는 backbone 정확도를 유지하면서 거의 O(n log n) 에 가까운 실행시간을 보임
- 32k 토큰까지 실행시간이 거의 평평하게 유지됨
- abstract 기준으로 128k-token context에서 FlashAttention-2보다 최대 7× 빠름
- 본문 contribution 목록 기준으로는 32k 토큰에서 FlashAttention-2보다 최대 7× 빠른 추론을 보임
벤치마크 결과와 실용 범위
- SPECTRE는 PG-19 언어 모델링과 ImageNet-1k 분류에서 baseline 성능과 같거나 더 높은 결과를 보임
- 긴 컨텍스트 처리에서 self-attention의 2차 비용을 피하면서도 글로벌 컨텍스트 믹싱을 유지함
- sparse pattern, kernel approximation, low-rank structure 기반 attention 가속 방식은 exactness 희생, 비표준 최적화, streaming generation 미지원 같은 한계가 있을 수 있음
- SPECTRE는 FFT가 circular convolution을 대각화해 global mixing을 element-wise product로 바꾸는 주파수 영역 접근을 사용함
- 추가 파라미터를 6% 미만으로 제한해, specialized hardware 없이 commodity GPU에서 hundred-kilotoken context 처리를 목표로 함
댓글과 토론
Hacker News 의견들
-
기본적으로 합성곱 정리를 활용하는 방식임: 원래 공간에서 비싼 합성곱이 상호 공간에서는 단순한 곱셈이 되고, 그 반대도 성립함
데이터에 합성곱 연산이 있으면 켤레 영역으로 변환해 곱셈으로 바꾸면 됨
달리 말하면, 데이터에 자연스러운 영역에서 작업하라는 뜻
https://en.wikipedia.org/wiki/Convolution_theorem- 이렇게 표현하니 아주 좋지만, LLM에서 구조화된 어텐션 공간이 주파수 영역이라는 점은 내게는 전혀 자명하지 않았음
- 기본적인 수학적 공간 변환 샌드위치임: 1) 데이터를 다른 공간으로 바꾸고 2) 그 공간에서 연산한 뒤 3) 원래 공간으로 되돌림
최적화하려면 각 단계를 최적화하고, 가능한 한 가장 효율적인 공간에서 많이 작업하면 됨 - “데이터에 자연스러운 영역에서 작업하라”는 말에서, 왜 곱셈이 합성곱보다 어떤 영역에 더 자연스럽다고 봐야 하는지 모르겠음
단지 계산이 더 쉬운 것과는 다른 이야기 아닌가? - 상호 공간은 항상 주파수 = 1/시간처럼 그냥 1/공간 형태인가?
- 맞지만 절약은 이론적인 면이 큼. O(n²) 연산을 O(nlog n)으로 바꾸는 건 좋아 보이지만, 평균 n이 3이라는 걸 깨닫기 전까지의 이야기임
게다가 계산에 복소수를 써야 하고, 수치적으로도 덜 안정적임. 내가 아는 한 FFT는 일반적인 합성곱에서는 이득이 아님
자기 어텐션이나 이 논문의 용도에서는 n이 훨씬 클 수도 있음. 논문은 안 읽었음. 그래도 복소수 문제는 남음
-
Google은 2022년에 FNet: Mixing Tokens with Fourier Transforms로 이 아이디어를 도입했음
이후 대부분의 상황에서 TPU의 행렬 곱셈 성능이 FFT보다 더 빠르다는 걸 알게 됨
https://arxiv.org/abs/2105.03824- 이 논문에서도 인용됨:
“전반적으로 FNet, Performer, 희소 트랜스포머 같은 접근은 고정 또는 근사 토큰 혼합으로 계산 부담을 줄일 수 있음을 보여주지만, 우리의 적응형 스펙트럼 필터링 전략은 FFT의 효율성과 학습 가능하고 입력 의존적인 스펙트럼 필터를 독특하게 결합한다. 이는 복잡한 시퀀스 모델링 작업에 중요한 확장성과 적응성의 강력한 조합을 제공한다.”
그 뒤에 비교 섹션도 있음 - 특수 하드웨어가 더 낫다는 비교는 좀 이상해 보임
그런데 DSP에는 FFT를 돕는 전용 하드웨어가 있나? 진짜로 궁금해서 묻는 것임. 써본 적은 없지만 어렴풋이 도움이 될 것 같음 - GPU는 TPU보다 10% 개선을 보였음
“TPU는 푸리에 변환에서 너무 비효율적이라 연구자들은 4096개 미만 시퀀스에서는 FFT 알고리즘을 쓰지 않고, 미리 계산한 DFT 행렬을 사용하는 이차 스케일링 푸리에 변환 구현을 선택했다.”
“Nvidia Quadro P6000 GPU에서는 FNet 아키텍처에서 푸리에 변환이 추론 시간의 최대 30%를 차지했다.”
이 회사는 2021년에 Google이 TPU에 자기들의 광 칩을 쓰면 추론 시간을 40% 줄일 수 있다고 주장했음. FFTNet이 더 많은 일을 맡으면 더 줄어들 수도 있음
https://scribe.rip/optalysys/attention-fourier-transforms-a-... - 문맥 창의 토큰 수를 늘릴수록 FFT의 스케일링이 더 좋아질 것 같음. Google 모델들이 문맥 크기에서 경쟁자들을 앞서는 점이 흥미로움
- FFT보다 빠른 것뿐만 아니라, TPU의 FFT 지원은 항상 최선 노력 수준이었음. 마지막으로 시도했을 때는 심각한 정밀도 문제가 있었음
- 이 논문에서도 인용됨:
-
푸리에 변환은 “토큰” 차원을 따라 적용됨. 하지만 많은 응용에서는 이 차원이 의미를 갖지 않음. 그래서 트랜스포머가 순열 불변 데이터를 처리하는 데 좋은 선택지가 됨
덜 알려진 유한군 위의 푸리에 변환을 사용한 추가 실험을 보고 싶음. 이는 순열 불변이면서도 표준 푸리에 변환과 많은 성질을 공유함
또 이것이 LLM의 다음 큰 흐름이 된다면, vLLM이나 llama.cpp 같은 추론 엔진이 얼마나 쉽게 통합할 수 있을지도 궁금함
https://en.wikipedia.org/wiki/Fourier_transform_on_finite_gr...- 이 분야 전문가는 아니지만, 대부분의 모델에서는 토큰이 위치 의존 정보와 함께 변환되지 않나?
llama는 입력 내 위치에 따라 벡터에 회전을 적용하는 것으로 알고 있음 - 이 경우의 유한군은 무엇인가?
- 이 분야 전문가는 아니지만, 대부분의 모델에서는 토큰이 위치 의존 정보와 함께 변환되지 않나?
-
수학은 완전히 머리 위로 지나가고, 수식 주변 설명도 겨우 이해하는 수준임. 누가 쉬운 말로 이것이 어떻게 어텐션 메커니즘과 동등한지 설명해 줄 수 있나?
여기서 말하는 주파수는 무엇이고, 토큰 간 위치 관계는 어떻게 인코딩하나?- 푸리에 변환은 가역 연산자임. 즉 함수에 작용하며, 행렬의 경우 함수와 연산자 모두 행렬로 표현될 수 있음. 이를 우리가 주파수 공간이라고 부르는 곳으로 변환함
신호 분석이나 이미지에서는 가장 직관적임: https://homepages.inf.ed.ac.uk/rbf/HIPR2/fourier.htm
주파수 공간은 본질적으로 복소수로 표현되는 “복소” 공간임. 주파수는 문제를 전역적으로 바라본다는 장점이 있음
이 메커니즘은 어텐션 메커니즘과 동등하지 않으며, 분명한 절충이 있음. 다만 어텐션이 포착하는 중요한 관계 중 상당수를 포착할 가능성은 있음
modReLU에 대해서는 당장 좋은 직관이 없지만, 주파수를 수정하면서도 역푸리에 변환을 보존하기 때문에 중요한 것으로 보임 - 실제 메커니즘 자체는 꽤 단순함. 입력 임베딩에 FFT를 적용하고, 입력 임베딩에서 MLP로 얻은 가중치와 원소별 곱을 한 뒤, 상수지만 학습 가능한 편향을 더하고, 활성화 함수를 거친 다음 마지막으로 역 FFT를 적용함
여기서 “주파수”는 아마 꽤 추상적인 것일 가능성이 큼. FFT는 명확한 주파수 해석이 없는 방식으로도 자주 쓰임. 합성곱 정리 같은 편리한 수학적 성질 때문에 쓰는 경우가 많음
정말 잘 작동한다면 꽤 놀랍고, 매우 우아함 - 전문가는 전혀 아니지만 직관을 조금 보태자면, 자기 어텐션은 결국 매개변수화된 토큰 혼합기임
즉 출력의 각 벡터는 해당 입력 벡터가 다른 모든 입력 벡터들의 어떤 함수에 의해 변환된 것에 의존함
https://medium.com/optalysys/attention-fourier-transforms-a-...
개념적으로 이것이 약간 단순화된 합성곱과 어떻게 비슷한지 볼 수 있음: https://openreview.net/pdf?id=8l5GjEqGiRG
합성곱은 어떤 방식으로든 전역 상태를 고려하고 싶을 때 자주 사용됨
- 푸리에 변환은 가역 연산자임. 즉 함수에 작용하며, 행렬의 경우 함수와 연산자 모두 행렬로 표현될 수 있음. 이를 우리가 주파수 공간이라고 부르는 곳으로 변환함
-
이 프레임워크에 인과적 마스킹을 넣으려면 n개의 서로 다른 FFT를 해야 할 것 같은데, 위치 임베딩에 대한 언급도 없음
그래서 비교 대상 자기 어텐션 구현은 비인과적 NoPE인 듯하고, 그렇다면 기준선을 일부러 약하게 잡은 사례라 그리 인상적이지 않을 수도 있음
결과가 최신 수준에 가까웠다면 저자가 아마 언급했을 것 같음- Long Range Arena(LRA) 벤치마크에서는 자기 모델이 모든 범주에서 이긴다고 보여주긴 함. 패배한 범주나 더 나은 모델을 제외하지 않았기를 바람
-
관련 참고문헌으로 보임: https://arxiv.org/abs/2111.13587
Adaptive Fourier Neural Operators: Efficient Token Mixers for Transformers
John Guibas, Morteza Mardani, Zongyi Li, Andrew Tao, Anima Anandkumar, Bryan Catanzaro -
여기서 주파수 영역으로 보는 것이 왜 도움이 되는지 직관이 있는지 궁금함
직류 성분은 이해가 되지만, 입력 데이터가 다른 주파수들이 의미를 가질 만큼 충분히 주기적일 거라고는 기대하지 않음 -
몇 년 전에 이미 O(n log n) 전체 문맥 혼합을 보여준 Hyena Operator 선행 연구가 언급되지 않은 것 같음
https://arxiv.org/abs/2302.10866- Hyena는 같은 연구실의 Albert Gu가 했던 선행 작업에서 나왔음
https://arxiv.org/abs/2111.00396
- Hyena는 같은 연구실의 Albert Gu가 했던 선행 작업에서 나왔음
-
빅오 표기법은 어느 정도 감을 잡지만, 컴퓨터공학이나 전기공학과 관련된 대부분의 내용처럼 이것도 머리 위로 지나감
수학을 정말 못하는 입장에서, 이런 내용을 이해하거나 적어도 배워서 공학 학위와 면허까지 딸 수 있는 사람들이 부러움
FFT에 대해 아는 건 신호를 바꾸고, 어떤 종류의 신호 처리에 쓰이며, 예전에 핵폭발 탐지의 핵심이었다고 들었다는 정도임- 푸리에 변환에 대한 괜찮은 직관은, 손으로 푸리에 변환을 유도하거나 FFT 알고리즘을 직접 짤 수 없더라도 매우 유용한 도구임
기본 아이디어는 이렇다: 거의 모든 유용한 신호는 서로 다른 주파수와 위상을 가진 사인파들의 합으로 표현될 수 있음. 예를 들어 전기 신호나 음파는 x축이 시간인 1차원 신호임. 보기에는 다루기 어려운 복잡한 꼬불꼬불한 선일 수 있음
푸리에 변환을 쓰면 시간 기반 신호의 개별 주파수를 분리할 수 있음. 그런 다음 특정 주파수를 원하는 방식으로 수정할 수 있음. 예컨대 신호에 무작위적인 뾰족한 잡음이 많으면 이는 높은 주파수로 나타남. 정리하려면 푸리에 변환을 하고, 특정 임계값보다 높은 주파수의 데이터를 버린 뒤, 남은 데이터에 역푸리에 변환을 적용해 원래 신호의 더 매끈한 버전으로 되돌리면 됨. 이것을 저역 통과 필터라고 하며, 원래 신호의 이동 평균을 취하는 것과 거의 비슷함
재미있는 부분은 이것을 꽤 직관적으로 더 높은 차원으로 확장할 수 있다는 점임. x축과 y축이 모두 공간인 2차원 신호는 이미지임. JPEG 압축은 이 개념에 기반함. 이미지를 더 작게 저장하기 위해 고주파 신호를 제거하고, 그 대가로 미세한 디테일을 잃거나 너무 많이 버리면 링 모양 아티팩트가 생김. 여기에 시간이라는 세 번째 차원을 더하면 동영상이 되고, 계속 확장 가능함
이 모든 것이 시각적으로 이해하기 좋아서, 수학을 전부 깊이 알지 않아도 좋은 직관을 얻을 수 있음. 시각화와 인터랙티브 예제가 많은 좋은 페이지: https://www.jezzamon.com/fourier/index.html
3Blue1Brown 영상도 설명을 잘함: https://youtu.be/spUNpyF58BY?si=dz0z-s8NftW3Htun - 간단히 말하면, 마이크로 측정한 오디오 신호처럼 1차원 시간 영역 신호가 있다고 해보자. 마이크가 고정되어 있다면 특정 지점에서 시간에 따른 공기의 변위를 측정하는 것임
FFT가 이산 버전인 푸리에 변환은 그 1차원 시간 영역 신호를 주파수 대비 크기와 위상 성분으로 분해함
주파수는 기본적으로 음높이임. 순수 사인파나 순수 톤은 예전 밤늦게 TV 방송 종료 때 들리던 소리와 비슷한데, 이 경우 대부분은 0이고 해당 톤의 주파수 위치에 하나의 “스파이크”가 생김. 신호 진폭이 클수록 스파이크의 크기도 커짐. 음높이, 즉 주파수가 올라가거나 내려가면 이 스파이크 위치가 가로축을 따라 위아래로 움직임
위상은 기본적으로 신호의 시간 오프셋임. 어떤 식으로든 지연된 톤은 다른 위상으로 나타남. 다만 이는 절대 측정이 아니라 상대 측정임. 단위가 라디안, 즉 각도라서 원을 한 바퀴 돌면 다시 “리셋”되므로, 신호가 1초 밀렸는지 2초 밀렸는지 같은 것은 알 수 없음
그래서 하나의 신호, 즉 시간 대비 진폭에서 실제로는 주파수 대비 크기와 위상이라는 두 가지 정보를 얻음
허수나 복소변수를 이해한다면, 이 두 신호는 사실 복소 함수인 FFT 출력의 크기와 편각일 뿐임
- 푸리에 변환에 대한 괜찮은 직관은, 손으로 푸리에 변환을 유도하거나 FFT 알고리즘을 직접 짤 수 없더라도 매우 유용한 도구임
-
텔레메트리의 시대에, 클라우드 텔레메트리에 FFT를 적용해 주기적 이상과 준안정 시스템을 사고가 터진 뒤가 아니라 전에 찾아내지 않는 건 큰 기회를 놓치는 것 같음
불행히도 이건 내가 알아차릴 수 있는 수준에는 있지만, 구현할 기술 수준에는 없고 이미 일정도 꽉 차 있음
“SLA는 서비스 배포 후 23~25분 뒤에 가장 위반되기 쉽다. 흠, 왜 그럴까… 아 안 돼.”- “죄송하지만 Dave, 당신의 애플리케이션은 배포할 수 없습니다”
농담은 제쳐두고, 이게 정말 돈이 될 수 있는 영역은 트래픽 주기를 예측해서 서버 인스턴스를 올리고 내리며 비용을 절감하는 것임
이런 일은 개인 시간으로 하면 회사가 절대 승인해주지 않겠지만, 기성 제품으로 포장하면 회사가 바로 살 종류의 작업임
- “죄송하지만 Dave, 당신의 애플리케이션은 배포할 수 없습니다”