Triton Tutorial Practice: 06 Flash-Attention & Triton Implement
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
關於 FlashAttention 的發展請看
FlashAttention123
為什麼要 Fuse Attention
簡單的答案:為了要節省 GPU 記憶體的使用
複雜一點的答案:為了要節省 Self-Attesion 使用記憶體量
Self-Attention 是什麼,為什麼是 bottleneck
由於 1~3步中對於 HBM 的存取,導致整體成為 memory bound 的task
(ND + NN)
N >> D
S 和 P 的 memory access (N*N 的複雜度) 便是整體 self-attention 的 bottleneck!
所以 Attention 成為了過程中的 bottleneck
如何避免 S、P進出 HBM 就計算到 O
S跟P 需要被儲存的兩個原因
- softmax 需要等待 row 算完
- backpropagation 需要 intermediate activations
跳過 S跟P的計算,P使用 tiling 的方式直接計算 O 的部分
Recomputation: 不儲存 intermediate activations
因為 Memory 存取很慢,
所以跳過 Memory存取 等到需要用的時候直接重算
這樣就節省了兩次 Memory 存取,最後反而比存起來還要快
教學實作
在開始之前…
這篇教學如果使用一般方式安裝的 triton 可能會發生
'CudaDriver' object has no attribute 'active'
的錯誤訊息
參照 另一篇 issue 中的狀況 https://github.com/triton-lang/triton/issues/4014
改安裝 nightly 的版本可以迴避,但這個 tutorial 最後還是會出現 index error (雖然第一張結果是跑得出來啦)
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
但 改用 .py 檔的 script 就可以跑完流程,會跟網頁上的圖片不一樣,但這看起來才是正常的
forward pass
STAGE ==1 :
STAGE ==2 :
qk_scale:
step1&2 將 Q,K,V,O 切成對應分割
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=v_order,
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
step3: main loop
step6: inner loop
執行 _attn_fwd_inner()
step7: load K,V
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
step8: 計算 S (qk)
step9: 計算 mij (m_ij)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
step10:計算 O
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
# -- update output accumulator --
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr)
if fp8_v:
p = p.to(tl.float8e5)
else:
p = p.to(tl.float16)
acc = tl.dot(p, v, acc)
https://triton-lang.org/main/python-api/generated/triton.language.make_block_ptr.html
backward pass
Reference
社群討論
How close could triton get to FlashAttention v3 performance? https://tridao.me/blog/2024/flash3
guesstimate is maybe within 20-30% with appropriately tuned tile size. There are some Triton numbers in this blog post that do not make sense, such as causal being 30% slower than non-causal for FP8 at 16k context