8.2 FlashAttention-2 - 병렬성 최적화와 작업 분할의 재설계
2025-12-21, G30DR
트랜스포머(Transformer) 아키텍처가 등장한 이래, 모델의 크기와 처리 가능한 문맥(Context)의 길이를 확장하려는 시도는 딥러닝 연구의 핵심적인 흐름이었다. 특히 GPT-4, Claude와 같은 거대 언어 모델(LLM)이 수만 토큰 이상의 긴 문맥을 처리해야 하는 상황에서, 어텐션(Attention) 메커니즘이 가진 O(N^2)의 시간 및 메모리 복잡도는 시스템 확장의 가장 큰 병목으로 작용해 왔다. 이러한 문제를 해결하기 위해 등장한 FlashAttention-1은 GPU 메모리 계층 구조(Memory Hierarchy)의 비대칭성을 활용, 고대역폭 메모리(HBM) 접근을 최소화하는 타일링(Tiling) 기법을 통해 메모리 사용량을 선형(O(N))으로 줄이고 연산 속도를 획기적으로 개선하였다.1
그러나 FlashAttention-1이 ‘메모리 병목(Memory-Bound)’ 문제를 해결했음에도 불구하고, 여전히 GPU가 가진 이론적 최대 연산 성능(Theoretical Max FLOPs)을 온전히 활용하지 못한다는 한계가 발견되었다. 구체적으로 NVIDIA A100 GPU 기준, 최적화된 행렬 곱(GEMM) 연산이 이론 성능의 80-90%를 달성하는 반면, FlashAttention-1은 25-40% 수준의 활용률(Utilization)에 머물렀다.1 이는 어텐션 연산의 병목이 메모리 대역폭에서 연산 파이프라인(Compute Pipeline) 자체의 비효율성으로 이동했음을 시사한다.
FlashAttention-2는 이러한 문제의식에서 출발하여, 알고리즘 설계를 밑바닥부터 재검토한 결과물이다. 이 기술은 단순히 메모리 I/O를 줄이는 것을 넘어, GPU의 스레드 블록(Thread Block)과 워프(Warp) 간의 작업 분할(Work Partitioning)을 혁신하고, 병렬성(Parallelism)의 차원을 확장하며, 비 행렬 곱 연산(Non-matmul FLOPs)을 최소화함으로써 A100 GPU 기준 약 225 TFLOPs/s, 이론 성능의 73%에 달하는 효율을 달성하였다.2 본 장에서는 FlashAttention-2가 어떠한 원리로 이러한 비약적인 성능 향상을 이루어냈는지, 그 기술적 메커니즘과 구현 디테일, 그리고 하드웨어 수준에서의 최적화 전략을 심층적으로 분석한다.
1. 서론: 메모리 병목에서 연산 병목으로의 전환
1.1 FlashAttention-1의 유산과 한계
2022년 발표된 FlashAttention-1은 ’IO 인식(IO-Awareness)’이라는 개념을 도입하여 어텐션 연산의 패러다임을 전환했다. 기존의 어텐션 구현이 QK^T 행렬과 소프트맥스 결과를 HBM에 기록하고 다시 읽어오는 과정에서 막대한 대역폭을 소모했던 것과 달리, FlashAttention-1은 입력 블록을 고속의 SRAM(On-chip Memory)에 올린 뒤 연산과정을 융합(Fusion)하여 중간 결과를 HBM에 쓰지 않는 방식을 채택했다. 이를 통해 2-4배의 속도 향상을 이루어냈으나, 이는 어디까지나 메모리 접근 비용을 줄인 결과였다.
연구진은 FlashAttention-1이 실행되는 동안 GPU의 연산 유닛, 특히 텐서 코어(Tensor Core)가 충분히 포화되지 않는 현상을 관찰했다.2 이는 크게 두 가지 원인에 기인한다. 첫째, 소프트맥스(Softmax) 계산과 같은 비 행렬 곱 연산이 전체 실행 시간의 상당 부분을 차지했다. 둘째, GPU의 수많은 코어(Streaming Multiprocessor, SM)들에 작업을 분배하는 방식이 최적화되지 않아, 긴 시퀀스 길이(Long Sequence Length)나 작은 배치 크기(Small Batch Size) 환경에서 병렬 처리 효율이 급격히 저하되었다. 즉, 하드웨어는 연산을 수행할 준비가 되어 있으나, 소프트웨어 스케줄링과 알고리즘 구조가 데이터를 충분히 빠르게 공급하거나 효율적으로 처리하지 못하는 ‘연산 바운드(Compute-Bound)’ 상황에 직면한 것이다.
1.2 차세대 최적화의 목표
FlashAttention-2의 목표는 명확했다. IO 최적화라는 FlashAttention-1의 유산을 계승하면서, GPU의 연산 처리량(Throughput)을 GEMM 수준으로 끌어올리는 것이다. 이를 위해 연구진은 다음 세 가지 핵심 전략을 수립했다.1
- 비 행렬 곱 연산의 최소화: 텐서 코어에 친화적이지 않은 연산을 알고리즘 수준에서 제거하거나 단순화한다.
- 시퀀스 차원의 병렬화: 기존의 배치 및 헤드 차원 병렬화에 더해 시퀀스 길이 차원까지 병렬화하여 GPU 점유율(Occupancy)을 극대화한다.
- 워프 간 통신 제거: 스레드 블록 내부의 워프들이 공유 메모리를 통해 통신하는 오버헤드를 줄이기 위해 작업 분할 방식을 재설계한다.
graph TD
subgraph "패러다임 전환 (Paradigm Shift)"
A["FlashAttention-1"] -->|"해결 (Solved)"| B["메모리 병목 (Memory-Bound)"]
A -->|"남겨진 문제 (Problem Remained)"| C["연산 병목 (Compute-Bound)<br/>낮은 GPU 점유율 (Low Occupancy)"]
C -->|"진화 (Evolution)"| D["FlashAttention-2"]
end
subgraph "FlashAttention-2의 3대 핵심 전략"
D --> G1["비 행렬 곱 연산 최소화<br/>(Minimize Non-matmul FLOPs)"]
D --> G2["시퀀스 차원 병렬화<br/>(Sequence Parallelism)"]
D --> G3["작업 분할 재설계<br/>(Split-Q: Remove Comm)"]
end
style A fill:#f9f,stroke:#333,stroke-width:2px
style D fill:#bbf,stroke:#333,stroke-width:4px
style C fill:#ffcccc,stroke:#333
2. 하드웨어 비대칭성과 비 행렬 곱 연산(Non-matmul FLOPs)의 최소화
현대 GPU 아키텍처는 행렬 곱셈 연산에 극도로 특화되어 있다. NVIDIA A100 GPU의 경우, FP16/BF16 텐서 코어를 활용한 행렬 곱셈의 이론적 최대 성능은 312 TFLOPs/s에 달한다. 반면, 지수 함수(Exponential)나 역수(Reciprocal) 계산과 같은 비 행렬 곱 FP32 연산 성능은 19.5 TFLOPs/s에 불과하다.5 이는 행렬 곱 연산이 비 행렬 곱 연산보다 약 16배 더 빠르다는 것을 의미한다. 즉, 알고리즘 내에 비 행렬 곱 연산이 조금만 포함되어 있어도 전체 성능에는 치명적인 병목으로 작용할 수 있다.
graph TD
subgraph "하드웨어 비대칭성 (Hardware Asymmetry: A100)"
HW1["텐서 코어 행렬 곱<br>(Tensor Core GEMM)"] -->|"속도 (Speed)"| S1["312 TFLOPs/s"]
HW2["비 행렬 곱 연산<br>(Non-matmul FP32)"] -->|"속도 (Speed)"| S2["19.5 TFLOPs/s"]
S1 --"약 16배 빠름 (16x Faster)"--> S2
end
subgraph "알고리즘 최적화 (Algorithm Optimization)"
OP1["기존 온라인 소프트맥스<br/>(Standard Online Softmax)"] -->|"빈번한 나눗셈 (Frequent Division)"| R1["반복마다 재스케일링<br>(Rescaling every iter)"]
OP2["FlashAttention-2 방식"] -->|"연산 지연 (Lazy Computation)"| R2["최종 단계까지 스케일링 지연<br/>(Delay scaling till end)"]
OP2 -->|"통계량 유지 (Keep Stats)"| R3["logsumexp만 갱신<br/>(Update logsumexp only)"]
end
HW2 -.->|"병목 발생 (Bottleneck)"| OP1
R2 -.->|"병목 해소 (Bottleneck Solved)"| HW1
2.1 온라인 소프트맥스(Online Softmax)의 재구성
FlashAttention-1은 수치적 안정성을 보장하기 위해 온라인 소프트맥스 알고리즘을 사용했다. 이 과정에서 각 블록의 국소적 최대값(Local Max)과 합계(Sum)를 갱신하며, 결과값을 지속적으로 재스케일링(Rescaling)해야 했다. FlashAttention-2는 이 과정에서 발생하는 비 행렬 곱 연산을 줄이기 위해 알고리즘을 수학적으로 재구성했다.
기존 방식은 반복마다 출력 행렬 O를 현재까지의 정규화 인자로 나누어주는 연산이 빈번했다. FlashAttention-2는 최종 단계 전까지는 스케일링을 지연시키거나, 로컬 통계량 L (logsumexp)만을 유지하고 업데이트하는 방식으로 나눗셈 연산의 횟수를 획기적으로 줄였다.1 또한, 역전파(Backward Pass) 과정에서도 \text{logsumexp} 값을 저장해 둠으로써 max와 sum을 다시 계산하는 비용을 절감했다.3
2.2 인과적 마스킹(Causal Masking) 및 바운드 체크 최적화
트랜스포머 모델, 특히 GPT 계열의 디코더 모델에서는 미래의 토큰을 참조하지 못하도록 하는 인과적 마스킹(Causal Masking)이 필수적이다. 기존에는 이를 구현하기 위해 모든 위치에 대해 마스크 연산을 수행하거나 조건문을 걸어야 했다. FlashAttention-2는 이러한 조건부 연산이 텐서 코어 파이프라인을 방해하지 않도록, 블록 단위로 마스킹 여부를 판단하여 처리가 필요한 블록과 그렇지 않은 블록을 조기에 분기(Early Exit)하거나, 마스킹 연산을 행렬 곱 연산과 융합하는 형태로 최적화했다.5 이러한 미세한 튜닝들은 전체 FLOPs 중 행렬 곱이 차지하는 비중을 높여 GPU가 텐서 코어를 최대한 가동할 수 있는 환경을 조성한다.
3. 병렬성(Parallelism)의 재해석: 시퀀스 길이 차원 활용
GPU 프로그래밍에서 병렬성은 성능을 결정짓는 가장 중요한 요소 중 하나다. GPU는 수천 개의 스레드를 동시에 실행하여 높은 처리량을 얻는데, 이를 위해서는 충분한 수의 독립적인 작업(Task)이 공급되어야 한다.
graph TD
subgraph "상황: 긴 문맥 & 작은 배치 (Long Context, Small Batch)"
Context["Batch=1, Heads=32, A100 GPU (108 SMs)"]
end
subgraph "FlashAttention-1: 배치/헤드 병렬화"
FA1_Logic["작업 단위: (Batch, Head)"]
FA1_Blocks["생성된 블록 수: 32개"]
FA1_Result["GPU 활용 (GPU Utilization)"]
Context --> FA1_Logic --> FA1_Blocks
FA1_Blocks -->|"108개 중 32개만 사용"| FA1_Result
FA1_Result -->|"76 SMs 유휴 상태 (Idle)"| LowPerf["성능 저하 (Low Performance)"]
end
subgraph "FlashAttention-2: 시퀀스 병렬화 추가"
FA2_Logic["작업 단위: (Batch, Head, Seq-Chunk)"]
FA2_Split["시퀀스 분할 (Split Sequence)"]
FA2_Blocks["생성된 블록 수: 32 * N개"]
FA2_Result["GPU 활용 (GPU Utilization)"]
Context --> FA2_Logic --> FA2_Split --> FA2_Blocks
FA2_Blocks -->|"108개 SM 모두 사용 가능"| FA2_Result
FA2_Result -->|"높은 점유율 (High Occupancy)"| HighPerf["성능 극대화 (Max Performance)"]
end
3.1 기존 병렬화의 한계: 배치(Batch)와 헤드(Head)
FlashAttention-1은 병렬화의 단위를 **배치 크기(Batch Size)**와 **헤드 수(Number of Heads)**로 설정했다. 즉, (배치 인덱스, 헤드 인덱스) 쌍마다 하나의 스레드 블록(Thread Block)이 할당되어 해당 어텐션 연산을 수행하는 구조였다.
이 방식은 배치 크기가 크거나 헤드 수가 많을 때는 효율적이다. 예를 들어, 배치 크기가 64이고 헤드 수가 12라면 총 768개의 스레드 블록이 생성된다. A100 GPU는 108개의 스트리밍 멀티프로세서(SM)를 가지고 있으므로, 768개의 블록은 108개의 SM을 가득 채우고도 남는다.
그러나 최근 LLM의 트렌드는 긴 문맥(Long Context)을 처리하는 것이다. 문맥 길이가 길어지면 메모리 제약으로 인해 배치 크기를 1 또는 2와 같이 매우 작게 줄일 수밖에 없다. 또한, 멀티 쿼리 어텐션(Multi-Query Attention, MQA)이나 그룹 쿼리 어텐션(GQA)과 같은 기법은 키(Key)와 값(Value)의 헤드 수를 줄이는 방향으로 발전하고 있다.
이 경우, 생성되는 스레드 블록의 수가 GPU의 SM 수보다 적어지는 상황이 발생한다. 예를 들어 배치 크기가 1이고 헤드 수가 32라면, 총 32개의 스레드 블록만이 생성된다. 108개의 SM을 가진 A100 GPU에서 32개의 블록만 실행된다면, 나머지 76개의 SM은 아무런 일도 하지 않고 유휴 상태(Idle)로 남게 된다. 이는 심각한 자원 낭비이며 성능 저하의 주범이다.3
3.2 시퀀스 길이 차원 병렬화(Sequence Length Parallelism)
FlashAttention-2는 이 문제를 해결하기 위해 시퀀스 길이 차원을 병렬화의 축으로 추가했다. 이는 Phil Tillet이 Triton 구현체에서 처음 제안하고 구현한 아이디어를 채용한 것이다.4
- Forward Pass: 쿼리(Query) 시퀀스 N을 여러 개의 블록(크기 B_r)으로 나눈다. 이제 스레드 블록은 (배치, 헤드, 시퀀스 청크)의 조합으로 정의된다. 예를 들어 시퀀스 길이가 8192이고 B_r=128이라면, 시퀀스 차원에서만 64개의 분할이 발생한다. 따라서 배치 크기가 1, 헤드 수가 1이라 하더라도 64개의 스레드 블록이 생성되어 GPU의 활용도를 높일 수 있다.
- Backward Pass: 역전파 시에는 쿼리, 키, 값에 대한 기울기(dQ, dK, dV)를 계산해야 한다. 특히 dQ를 계산할 때 시퀀스 차원의 병렬화를 적용하면, 서로 다른 스레드 블록이 동일한 dQ 메모리 영역에 값을 더해야 하는 상황(Race Condition)이 발생할 수 있다. FlashAttention-2는 이를 해결하기 위해 CUDA의 원자적 덧셈(Atomic Add) 연산을 활용한다. 비록 원자적 연산에 약간의 오버헤드가 있지만, 병렬성을 높여 얻는 이득이 훨씬 크기 때문에 전체적인 학습 속도는 향상된다.1
이러한 병렬화 전략의 변화는 긴 시퀀스를 처리할 때 GPU 점유율을 획기적으로 개선하며, 특히 배치 크기가 작은 추론(Inference) 단계나 파인튜닝(Fine-tuning) 단계에서 큰 성능 향상을 가져온다.7
4. 작업 분할(Work Partitioning)의 혁신: Split-K에서 Split-Q로
병렬화가 GPU의 SM들에게 작업을 어떻게 분배할 것인가의 문제라면, 작업 분할(Work Partitioning)은 하나의 SM(스레드 블록) 내부에서 여러 개의 **워프(Warp)**들이 어떻게 협력할 것인가의 문제다. FlashAttention-2는 이 내부 작업 분할 방식을 완전히 뒤집음으로써 불필요한 동기화 비용을 제거했다.
graph TD
subgraph "FlashAttention-1: Split-K (동기화 필수)"
K_Input["입력: Q (Shared), K/V (Split)"]
K_Warp1["Warp 1: Q x K1"]
K_Warp2["Warp 2: Q x K2"]
K_Input --> K_Warp1 & K_Warp2
K_Warp1 --> K_Part1["부분 합 (Partial Sum)"]
K_Warp2 --> K_Part2["부분 합 (Partial Sum)"]
K_Part1 & K_Part2 -->|"공유 메모리 쓰기 (Write Shared Mem)"| K_Shared["공유 메모리 (SMEM)"]
K_Shared -->|"동기화 대기 (Barrier Sync)"| K_Sync["모든 워프 대기"]
K_Sync -->|"값 읽기 & 합산 (Read & Reduce)"| K_Final["최종 합산 (Final Reduction)"]
end
subgraph "FlashAttention-2: Split-Q (통신 없음)"
Q_Input["입력: Q (Split), K/V (Shared/Streamed)"]
Q_Warp1["Warp 1: Q1 x All_K"]
Q_Warp2["Warp 2: Q2 x All_K"]
Q_Input --> Q_Warp1 & Q_Warp2
Q_Warp1 -->|"독립적 계산 (Independent Calc)"| Q_Out1["완전한 결과 O1 (Full Output)"]
Q_Warp2 -->|"독립적 계산 (Independent Calc)"| Q_Out2["완전한 결과 O2 (Full Output)"]
Q_Out1 -->|"즉시 HBM 기록 (Direct Write)"| Q_Final1["종료 (Finish)"]
Q_Out2 -->|"즉시 HBM 기록 (Direct Write)"| Q_Final2["종료 (Finish)"]
end
style K_Sync fill:#ff9999,stroke:#333
style Q_Final1 fill:#99ff99,stroke:#333
style Q_Final2 fill:#99ff99,stroke:#333
4.1 FlashAttention-1: Split-K (Sliced-K) 방식
FlashAttention-1은 Split-K 방식을 사용했다. 이 방식에서는 쿼리 행렬 Q가 모든 워프에 의해 공유되고, 키 K와 값 V 행렬이 여러 워프에 나뉘어(Split) 할당된다.
각 워프는 자신에게 할당된 K의 일부분과 전체 Q를 곱하여 QK^T의 부분적인 결과(Partial Result)를 계산한다. 그 후 V와 곱하여 최종 출력 O의 부분합(Partial Sum)을 얻는다.
문제는 이 ’부분합’들이 최종 결과가 아니라는 점이다. 서로 다른 워프들이 계산한 값들을 모두 더해야만 올바른 어텐션 출력을 얻을 수 있다. 따라서 워프들은 자신의 중간 결과를 공유 메모리(Shared Memory)에 기록하고, 모든 워프가 작업을 마칠 때까지 기다린(Synchronization Barrier) 후, 다시 값을 읽어와 합산(Reduction)해야 한다.
이 과정에서 발생하는 잦은 공유 메모리 읽기/쓰기와 동기화 대기 시간은 연산 파이프라인의 흐름을 끊고 성능을 저하시키는 주요 원인이었다.1
4.2 FlashAttention-2: Split-Q 방식
FlashAttention-2는 Split-Q 방식을 채택하여 이 문제를 해결했다. 이 방식에서는 K와 V 행렬이 모든 워프에 의해 공유되고, Q 행렬이 여러 워프에 나뉘어 할당된다.
- 동작 원리: 각 워프는 자신에게 할당된 Q의 일부분(예: 특정 행들)을 가져와 전체 K와 곱셈을 수행한다. 이렇게 계산된 어텐션 점수(Score) 행렬은 Q의 해당 행에 대한 온전한 값이다.
- 동기화 제거: 이후 전체 V와 곱셈을 수행하여 얻은 출력값 역시 최종 출력 행렬 O의 해당 행에 대한 완성된 값이다. 즉, 다른 워프의 결과와 더하거나 합칠 필요가 없다.
- 결과: 워프 간의 통신이 전혀 필요 없게 된다(No Communication). 각 워프는 계산이 끝나는 즉시 결과를 공유 메모리를 거치지 않고 글로벌 메모리(HBM)에 기록하거나 다음 연산을 수행할 수 있다.
이러한 Split-Q 방식은 포워드 패스에서 공유 메모리 접근과 동기화 비용을 사실상 제거하여 획기적인 속도 향상을 가능하게 했다.1 다만 역전파 과정에서는 Q, K, V 및 그 기울기들 간의 복잡한 의존성으로 인해 여전히 일부 동기화가 필요하지만, 이 경우에도 Split-K 방식 대비 공유 메모리 사용량은 줄어든다.4
다음 표는 두 방식의 차이를 요약한 것이다.
| 비교 항목 | FlashAttention-1 (Split-K) | FlashAttention-2 (Split-Q) |
|---|---|---|
| 분할 대상 (Split) | Key (K), Value (V) | Query (Q) |
| 공유 대상 (Shared) | Query (Q) | Key (K), Value (V) |
| 워프의 계산 결과 | 최종 값의 부분합 (Partial Sum) | 최종 값의 일부분 (Output Slice) |
| 워프 간 통신 | 필수 (결과 합산 위해) | 불필요 (Forward Pass 시) |
| 공유 메모리 부하 | 높음 (중간 결과 Read/Write) | 낮음 (입력 데이터 로딩 용도) |
| 동기화 (Barrier) | 빈번함 | 최소화 또는 없음 |
5. 구현의 기술적 토대: CUTLASS와 하드웨어 추상화
FlashAttention-2의 알고리즘적 혁신을 실제 하드웨어 성능으로 연결한 것은 정교한 소프트웨어 구현 능력이다. 저자 Tri Dao는 FlashAttention-2를 구현하기 위해 기존 코드를 완전히 폐기하고 처음부터 다시 작성(Rewrite)하였으며, 이 과정에서 NVIDIA의 CUTLASS 라이브러리를 핵심적으로 활용했다.1
5.1 CUTLASS 3.x와 CuTe
CUTLASS(CUDA Templates for Linear Algebra Subroutines)는 CUDA C++ 템플릿 라이브러리로, 텐서 코어를 활용한 고성능 행렬 연산을 추상화하여 제공한다. 특히 FlashAttention-2 개발에는 CUTLASS 3.0 버전과 그 안에 포함된 레이아웃 대수 라이브러리인 CuTe가 사용되었다.
- 추상화의 힘: CuTe는 복잡한 다차원 텐서의 레이아웃과 메모리 접근 패턴을 직관적으로 정의할 수 있게 해준다. 이를 통해 개발자는 스레드 블록, 워프, 개별 스레드 수준에서 데이터가 어떻게 이동하고 분배되는지를 정밀하게 제어할 수 있다. FlashAttention-2가 복잡한 Split-Q 파티셔닝과 파이프라이닝을 구현할 수 있었던 것은 이러한 강력한 추상화 도구 덕분이다.4
- 파이프라이닝 최적화: A100 GPU는 비동기 복사(Asynchronous Copy,
cp.async) 명령어를 지원한다. FlashAttention-2는 이를 활용하여 글로벌 메모리에서 공유 메모리로 데이터를 복사하는 작업과, 텐서 코어가 연산을 수행하는 작업을 중첩(Overlap)시킨다. 즉, 연산이 진행되는 동안 다음 데이터를 미리 가져오는 파이프라이닝을 통해 메모리 지연 시간을 숨긴다(Latency Hiding). CUTLASS는 이러한 소프트웨어 파이프라이닝을 구현하기 위한 빌딩 블록을 제공한다.8
5.2 Triton 및 기타 백엔드 지원
FlashAttention-2의 아이디어 중 일부, 특히 시퀀스 병렬화는 OpenAI가 개발한 GPU 프로그래밍 언어인 Triton 커뮤니티(Phil Tillet 등)와의 교류에서 영감을 받았다.1 현재 FlashAttention-2 라이브러리는 CUDA 구현뿐만 아니라 AMD ROCm 지원을 위한 Triton 백엔드도 실험적으로 제공하고 있다.9 이는 FlashAttention-2가 특정 하드웨어에 종속되지 않고 다양한 AI 가속기로 확장될 수 있는 가능성을 보여준다.
6. 성능 평가 및 벤치마크 분석
FlashAttention-2의 성능 향상은 단순한 수치를 넘어 AI 모델 학습의 경제성을 변화시키는 수준이다. 다음은 주요 GPU 아키텍처에서의 성능 분석 결과다.
graph TD
subgraph "A100 GPU 성능 (Performance)"
A1["이론적 최대 성능 (Peak)"] ---|"312 TFLOPs"| A_Peak
A2["FlashAttention-2"] ---|"225 TFLOPs (72% Util)"| A_FA2
A3["FlashAttention-1"] ---|"124 TFLOPs"| A_FA1
A_FA1 -->|"2배 가속 (2x Speedup)"| A_FA2
end
subgraph "H100 GPU 성능과 한계 (Limit)"
H1["절대 속도 (Absolute Speed)"] ---|"338 TFLOPs"| H_Speed
H2["이론 성능 대비 효율 (Efficiency)"] ---|"35% (Low Util)"| H_Util
H_Util -->|"원인 (Cause)"| H_Reason["TMA 미사용 (No TMA)<br/>WGMMA 미활용"]
H_Reason -->|"차세대 연구 (Next Gen)"| FA3["FlashAttention-3"]
end
6.1 A100 GPU (Ampere) 성능 분석
A100 80GB SXM4 GPU에서 FlashAttention-2는 놀라운 성능 지표를 기록했다.
- 최대 처리 속도: FP16/BF16 정밀도 기준, FlashAttention-2는 약 225 TFLOPs/s의 연산 속도를 달성했다. 이는 A100의 이론적 최대 성능(Peak Throughput)인 312 TFLOPs/s의 약 **72%**에 해당한다.1
- GEMM과의 비교: 통상적으로 고도로 최적화된 행렬 곱(GEMM) 라이브러리(cuBLAS 등)가 80-90%의 효율을 보이는 것을 고려할 때, 복잡한 어텐션 로직을 포함하면서도 70% 이상의 효율을 달성한 것은 어텐션 연산이 더 이상 ‘비효율적인’ 연산이 아님을 증명한다.
- 속도 향상: 기존 FlashAttention-1이 약 124 TFLOPs/s를 기록한 것과 비교하면 약 **2배(2x)**의 속도 향상이다. 표준 PyTorch 어텐션 구현과 비교하면 최대 9배까지 빠르다.10
6.2 H100 GPU (Hopper) 성능 분석
H100 GPU에서의 성능은 절대적인 수치에서는 크게 증가했으나, 상대적인 효율성 측면에서는 과제를 남겼다.
- 절대 성능: 시퀀스 길이 16k, 헤드 차원 128 기준 338 TFLOPs/s를 기록했다.1 이는 A100에서의 성능을 훨씬 상회하며, FlashAttention-1(139 TFLOPs/s) 대비 약 2.4배 빠르다.
- 효율성 문제: H100 SXM5의 FP16 텐서 코어 이론 성능은 989 TFLOPs/s에 달한다. 따라서 338 TFLOPs/s는 이론 성능의 약 **35%**에 불과한 활용률이다.11
- 원인 분석: H100의 연산 유닛은 A100 대비 3배 이상 빨라졌지만, 메모리 대역폭이나 비 행렬 곱 연산 유닛의 속도 향상은 그에 미치지 못했다. 또한 FlashAttention-2는 H100의 새로운 기능인 TMA(Tensor Memory Accelerator)나 WGMMA(Warp Group MMA) 명령어를 완전히 활용하지 못했다. 이는 후속작인 FlashAttention-3가 등장하게 된 직접적인 배경이 된다.14
6.3 시퀀스 길이에 따른 확장성
다음 표는 시퀀스 길이 변화에 따른 A100 GPU에서의 속도(TFLOPs/s) 변화를 보여준다 (헤드 차원 128 기준).4
| 시퀀스 길이 (Sequence Length) | FlashAttention-1 (TFLOPs/s) | FlashAttention-2 (TFLOPs/s) | 속도 향상 (Speedup) |
|---|---|---|---|
| 2k | ~100 | ~180 | ~1.8x |
| 4k | ~95 | ~195 | ~2.0x |
| 8k | ~90 | ~210 | ~2.3x |
| 16k | ~82 | ~203 | ~2.5x |
시퀀스 길이가 길어질수록 FlashAttention-1의 성능은 점차 하락하는 반면, FlashAttention-2는 시퀀스 병렬화 덕분에 높은 성능을 유지하거나 오히려 점유율(Occupancy)이 높아져 성능이 향상되는 경향을 보인다. 이는 긴 문맥을 다루는 최신 LLM 트렌드에 완벽하게 부합하는 특성이다.
7. 한계점과 차세대 최적화(FlashAttention-3)로의 연결
FlashAttention-2는 A100 세대의 GPU에서 어텐션 연산 효율을 극한으로 끌어올린 걸작이다. 그러나 기술의 발전은 멈추지 않으며, 새로운 하드웨어는 새로운 최적화 기법을 요구한다.
- H100에서의 낮은 활용률: 앞서 언급했듯 H100 GPU에서는 35%의 활용률에 그쳤다. H100은 비동기성을 더욱 강화한 TMA와 워프 그룹(Warp Group) 단위의 연산을 지원하는데, FlashAttention-2의 구조는 이러한 기능을 수용하기에 한계가 있었다.
- FP8 정밀도 지원: H100은 FP8 텐서 코어 연산을 지원하여 이론 성능을 2배 더 높일 수 있다. 그러나 FP8 연산은 수치적 불안정성을 동반하므로, 더욱 정교한 알고리즘적 처리가 필요하다.13
이러한 한계는 FlashAttention-3의 개발로 이어졌다. FlashAttention-3는 워프 특수화(Warp Specialization)와 생산자-소비자(Producer-Consumer) 모델, 그리고 FP8 지원을 통해 H100에서 75% 이상의 활용률(약 740 TFLOPs/s)을 달성하게 된다.13 하지만 FlashAttention-3의 눈부신 성과도 FlashAttention-2가 확립한 ‘IO 인식’, ‘시퀀스 병렬화’, ’Split-Q 작업 분할’이라는 견고한 토대가 있었기에 가능했다.
8. 결론
FlashAttention-2는 트랜스포머 모델의 연산 효율화 역사에서 중요한 변곡점을 차지한다. O(N^2)이라는 알고리즘적 복잡도는 수학적으로 극복하기 어렵지만, 하드웨어 아키텍처에 대한 깊은 이해를 바탕으로 상수항(Constant Factor)을 극도로 최소화함으로써 실질적인 비용을 해결 가능한 수준으로 낮추었다.
특히 GPU 메모리 계층 구조에 맞춘 IO 최적화뿐만 아니라, 스레드 스케줄링과 워프 간 통신 패턴까지 제어하는 ’시스템 수준의 최적화’가 딥러닝 알고리즘 연구에 얼마나 큰 파급력을 미칠 수 있는지를 증명했다. FlashAttention-2 덕분에 연구자들과 기업들은 수만 토큰에 달하는 긴 문맥을 가진 모델을 현실적인 시간과 비용 내에서 학습시킬 수 있게 되었으며, 이는 곧 “Transformer Singularity“를 가속화하는 핵심 동력이 되었다.
이제 FlashAttention-2는 PyTorch, Hugging Face Transformers, xFormers 등 주요 딥러닝 라이브러리에 기본적으로 통합되어 9, 전 세계 AI 연구 개발의 보이지 않는 엔진으로서 그 역할을 묵묵히, 그러나 강력하게 수행하고 있다.
9. 참고 자료
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning - Tri Dao, https://tridao.me/publications/flash2/flash2.pdf
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning - arXiv, https://arxiv.org/abs/2307.08691
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning | OpenReview, https://openreview.net/forum?id=mZn2Xyh9Ec
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning - arXiv, https://arxiv.org/pdf/2307.08691
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, https://hazyresearch.stanford.edu/blog/2023-07-17-flash2
- Choosing between NVIDIA H100 vs A100 - Performance and Costs Considerations, https://www.ori.co/blog/choosing-between-nvidia-h100-vs-a100-performance-and-costs-considerations
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning, https://crfm.stanford.edu/2023/07/17/flash2.html
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision | Tri Dao, https://tridao.me/blog/2024/flash3/
- Dao-AILab/flash-attention: Fast and memory-efficient exact attention - GitHub, https://github.com/Dao-AILab/flash-attention
- Aman’s AI Journal • Primers • FlashAttention, https://aman.ai/primers/ai/flashattention/
- The Evolution of Flash Attention: Revolutionizing Transformer Efficiency | by Saiii - Medium, https://medium.com/@sailakkshmiallada/the-evolution-of-flash-attention-revolutionizing-transformer-efficiency-8a039918d507
- FlashAttention — one, two, three! | by Najeeb Khan - Medium, https://medium.com/@najeebkan/flashattention-one-two-three-6760ad030ae0
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision, https://www.together.ai/blog/flashattention-3
- FlashAttention 1/2/3: Transformer attention optimizations - ALLPCB, https://www.allpcb.com/allelectrohub/flashattention-123-transformer-attention-optimizations
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision - NIPS papers, https://proceedings.neurips.cc/paper_files/paper/2024/file/7ede97c3e082c6df10a8d6103a2eebd2-Paper-Conference.pdf
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision - arXiv, https://arxiv.org/abs/2407.08608