FlashAttention 123 (本文編輯中
基礎知識
Standard Attention Implementation
Chain rule
Backward 時的順序:
Flash Attention1
問題定義
- 需要 intermediate activations 來幫助我們在 backward 的時候透過 chain rule 計算 gradients
- 由於 SRAM 本身不夠大,而 softmax 這種需要計算 sum 的 operation,需要整個 row 的 element 都到齊後才可以計算,使得我們沒有辦法 apply 一些 divide and conquery 的 algorithm ,更使得我們沒有辦法把所有運算一口氣在 SRAM 當中計算完
手段1 Tiling (online-softmax)
-
原本的做法:
- 直接計算 softmax,需要等每一個 row 計算完後才能得到分母的 exponential summation。
-
改進的做法:
- 將 S row vector 和 V col vector 切成多個 sub blocks,各自計算出各自的 value
- softmax
手段2 Recomputation
不儲存 intermediate activations 而是在有需要的時候再重新計算
- 原本的做法:
- forward 時將 S 和 P 這兩個 N*N 的 matrix 在 SRAM 計算完後存入 HBM
- backward 時再將兩個 Matrix 從 HBM load 到 SRAM
- 改進的做法:
- backward 有需要時從 SRAM 重新計算一次
- 相比於儲存 intermediate 的 I/O,直接拿在 cache 中的資料重算更快
實驗設計
結果
Flash Attention 2
問題定義
Flash Attention 還有一些常見的操作並沒有優化到,像是 GEMM。在 A100實驗中 forward pass 的時候只達到了理論最大值的 30~50%,backward pass 只有 25-35%。
has suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.
手段
better parallelism and work partitioning
調整 online-softmax實作
-
Forward Pass
- scale the final output at end of loop
- store the logsumexp
-
Backward Pass
- use the row-wise logsumexp 𝐿 instead of both the row-wise max and row-wise sum of exponentials in the softmax.
Parallelism 平行化運算
會有三種維度可以進行平行運算
- use the row-wise logsumexp 𝐿 instead of both the row-wise max and row-wise sum of exponentials in the softmax.
-
sequence length dimension
-
batch dimension
-
number of heads dimension
-
forward pass
- each worker takes care of a block of rows of the attention matrix
-
backword pass
*each worker takes care of a block of columns of the attention matrix
Wrap
use 4 or 8 warps per thread block
- Forward Pass
- Backward Pass
實驗設計
結果
Flash Attention 3
問題定義
在 H100 的架構下,uitilization 不佳
手段
- Producer-Consumer asynchrony
- Hiding softmax under asynchronous block-wise GEMMs
- Hardware-accelerated low-precision GEMM