Triton Tutorial Practice: 06 Flash-Attention & Triton Implementation

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

For the development history of FlashAttention, see
FlashAttention123

Why fuse attention?

Short answer: to save GPU memory
Slightly longer answer: to reduce the amount of memory that self-attention uses

What is self-attention, and why is it the bottleneck?

The attention mechanism

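Steps 1 to 3 of the standard implementation (computing S = QK^T, P = softmax(S), and O = PV) each read their inputs from HBM and write their results back to it, which makes the whole operation a memory-bound task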

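HBM access is on the order of O(ND + N^2), and N >> D.
So the memory access for S and P (N x N complexity) is the bottleneck of self-attention as a whole!

This is why attention becomes the bottleneck of the whole process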

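How to compute O without moving S and P in and out of HBM

Two reasons why S and P normally have to be stored

  1. softmax has to wait until a whole row has been computed
  2. backpropagation needs the intermediate activations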

Skip materializing the full S and P: use tiling to compute O block by block directly
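
The tiling idea can be sketched in plain Python. This is only an illustration under stated assumptions, not the tutorial's kernel: the function names, the block_n parameter, and the list-of-lists layout are made up for this example, and the 1/sqrt(d) scaling is left out for brevity. The point is that each query row keeps only a running maximum, a running denominator, and a running output, so a full row of S or P is never needed at once.

```python
import math

def naive_attention(Q, K, V):
    """Reference version: builds each full row of S and P before producing O."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) for k in K]  # full row of S
        m = max(s)
        p = [math.exp(x - m) for x in s]                    # full row of P (unnormalized)
        l = sum(p)
        out.append([sum(p[j] * V[j][d] for j in range(len(V))) / l
                    for d in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_n=2):
    """Visits K/V in blocks of block_n rows, keeping only running statistics."""
    D = len(V[0])
    out = []
    for q in Q:
        m_i = float("-inf")  # running row maximum
        l_i = 0.0            # running softmax denominator
        acc = [0.0] * D      # running unnormalized output row
        for start in range(0, len(K), block_n):
            k_blk = K[start:start + block_n]
            v_blk = V[start:start + block_n]
            s = [sum(a * b for a, b in zip(q, k)) for k in k_blk]  # one block of S
            m_new = max(m_i, max(s))
            alpha = math.exp(m_i - m_new)  # rescales everything accumulated so far
            p = [math.exp(x - m_new) for x in s]
            l_i = l_i * alpha + sum(p)
            acc = [acc[d] * alpha + sum(p[j] * v_blk[j][d] for j in range(len(v_blk)))
                   for d in range(D)]
            m_i = m_new
        out.append([a / l_i for a in acc])  # normalize once at the end
    return out
```

Both functions should agree up to floating-point rounding for any block size, which is what makes the fused version an exact attention rather than an approximation.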


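Recomputation: do not store intermediate activations

Memory access is slow,
so skip it and recompute the values when they are needed.
This saves two memory accesses (the write and the later read), and ends up faster than storing the values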
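
A minimal sketch of the recomputation idea, assuming the only thing saved per query row is one scalar (the log-sum-exp of that row's scores). The function names here are hypothetical and the code handles a single query row in plain Python; it is meant to show that a row of P can be rebuilt from q, K and that scalar, so P itself never has to be written out.

```python
import math

def forward(q, K, V):
    """Forward for one query row: returns the output row and one saved scalar."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    lse = m + math.log(sum(math.exp(x - m) for x in s))  # log-sum-exp of the row
    o = [sum(math.exp(s[j] - lse) * V[j][d] for j in range(len(V)))
         for d in range(len(V[0]))]
    return o, lse  # note: neither S nor P is returned

def recompute_p(q, K, lse):
    """Backward-time recomputation: rebuild the row of P from q, K and lse."""
    return [math.exp(sum(a * b for a, b in zip(q, k)) - lse) for k in K]
```

The trade is extra arithmetic (one more pass of dot products) in exchange for not storing N values per row.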

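Tutorial implementation

Before we start…

If triton is installed the usual way, this tutorial may fail with the error message
'CudaDriver' object has no attribute 'active'

See the situation described in another issue: https://github.com/triton-lang/triton/issues/4014

Installing the nightly build avoids the error, but the tutorial still hits an index error at the end (although the first result plot does get produced)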

pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

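Running it as a .py script instead completes the whole flow. The results differ from the figures on the web page, but they look like the correct ones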

forward pass

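STAGE == 1: inside _attn_fwd_inner(), the off-band part, i.e. the K/V blocks that lie entirely before the diagonal, so no causal mask is needed
STAGE == 2: the on-band part, i.e. the diagonal block, where the causal mask is applied (the STAGE == 2 branch in step 9)
qk_scale: the softmax scale multiplied by 1/ln(2) (about 1.44269504), so that the kernel can use exp2 instead of exp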
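
The kernel code further down calls tl.math.exp2 rather than exp. A small sketch of why that still gives an ordinary softmax, assuming qk_scale is the softmax scale with log2(e) folded in (the function names and the sm_scale parameter are made up for this example):

```python
import math

LOG2_E = 1.4426950408889634  # log2(e) = 1 / ln(2)

def softmax_exp(scores, sm_scale):
    """Ordinary softmax with base e."""
    z = [s * sm_scale for s in scores]
    m = max(z)
    p = [math.exp(x - m) for x in z]
    total = sum(p)
    return [x / total for x in p]

def softmax_exp2(scores, sm_scale):
    """Same softmax, computed with base 2 after folding log2(e) into the scale."""
    qk_scale = sm_scale * LOG2_E  # assumption: mirrors how qk_scale is built
    z = [s * qk_scale for s in scores]
    m = max(z)
    p = [2.0 ** (x - m) for x in z]  # 2**(x * log2(e)) equals e**x
    total = sum(p)
    return [x / total for x in p]
```

The two functions differ only by rounding, because 2^(x * log2(e)) = e^x.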

step 1 & 2: split Q, K, V, O into their corresponding blocks

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
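
One detail worth noticing above is that K_block_ptr declares the shape as (HEAD_DIM, N_CTX), so each loaded K block is already transposed and can be used for Q times K^T without an explicit transpose. The sketch below imitates that with a flat Python list. load_block is a hypothetical stand-in written for this example, not a Triton API, and it assumes K is stored contiguously as (N_CTX, HEAD_DIM).

```python
def load_block(buf, shape, strides, offsets, block_shape):
    """Hypothetical stand-in for loading a 2-D block through a block pointer."""
    rows, cols = block_shape
    r0, c0 = offsets
    block = []
    for i in range(rows):
        row = []
        for j in range(cols):
            r, c = r0 + i, c0 + j
            assert r < shape[0] and c < shape[1], "block runs past the parent tensor"
            row.append(buf[r * strides[0] + c * strides[1]])
        block.append(row)
    return block

N_CTX, HEAD_DIM, BLOCK_N = 4, 2, 2
# K stored row-major as (N_CTX, HEAD_DIM): element (n, k) sits at n * HEAD_DIM + k
K_flat = [10 * n + k for n in range(N_CTX) for k in range(HEAD_DIM)]
stride_kn, stride_kk = HEAD_DIM, 1  # assumed contiguous layout

# View the same memory as (HEAD_DIM, N_CTX) by swapping the strides
first = load_block(K_flat, (HEAD_DIM, N_CTX), (stride_kk, stride_kn),
                   (0, 0), (HEAD_DIM, BLOCK_N))
# Moving the offsets by (0, BLOCK_N) plays the role of advancing the pointer
second = load_block(K_flat, (HEAD_DIM, N_CTX), (stride_kk, stride_kn),
                    (0, BLOCK_N), (HEAD_DIM, BLOCK_N))
```

Here first holds rows 0 and 1 of K in transposed form, and second holds rows 2 and 3; no data is copied or rearranged in memory to get that view.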

step3: main loop

step6: inner loop
Runs _attn_fwd_inner()

step7: load K,V

    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

step8: compute S (qk)
step9: compute the row maximum m_ij

        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
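
To see what the mask expression in the STAGE == 2 branch produces, here is a plain-Python sketch using lists in place of Triton tensors. The function names and the tiny block sizes are chosen for this example only; the comparison itself follows the line mask = offs_m[:, None] >= (start_n + offs_n[None, :]).

```python
def causal_mask(start_m, start_n, block_m, block_n):
    """True where query row m is allowed to attend to key column n (n <= m)."""
    offs_m = [start_m * block_m + i for i in range(block_m)]  # absolute query rows
    offs_n = list(range(block_n))                              # columns inside the block
    return [[m >= start_n + n for n in offs_n] for m in offs_m]

def apply_mask(qk, mask):
    """Adds 0 where allowed and -1.0e6 where masked, like the tl.where call."""
    return [[x + (0.0 if keep else -1.0e6) for x, keep in zip(row, mrow)]
            for row, mrow in zip(qk, mask)]

# Diagonal block: query rows 2..3 against key columns 2..3
diag = causal_mask(start_m=1, start_n=2, block_m=2, block_n=2)
# Block before the diagonal: query rows 2..3 against key columns 0..1
before = causal_mask(start_m=1, start_n=0, block_m=2, block_n=2)
```

Only the diagonal block ends up with any masked entries, which is why the other branch can skip the mask entirely.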

step10: compute O

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
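
The alpha rescaling above can be checked by hand on a single row. The sketch below reuses the variable names m_i, m_ij, l_i, l_ij, alpha and acc, but works on plain floats with scalar values instead of V rows. The final division by l_i is not part of the snippet above; it is added here as an assumption so that the result is a proper softmax-weighted average.

```python
def stream_row(score_blocks, value_blocks):
    """score_blocks: blocks of already-scaled base-2 logits for one query row.
    value_blocks: matching blocks of scalar values."""
    m_i, l_i, acc = float("-inf"), 0.0, 0.0
    for qk, v in zip(score_blocks, value_blocks):
        m_ij = max(m_i, max(qk))
        p = [2.0 ** (x - m_ij) for x in qk]
        l_ij = sum(p)
        alpha = 2.0 ** (m_i - m_ij)  # shrinks old sums when the maximum grows
        l_i = l_i * alpha + l_ij
        acc = acc * alpha + sum(pj * vj for pj, vj in zip(p, v))
        m_i = m_ij
    return acc / l_i

# Logits 0, 1, 3 give weights 1, 2, 8 (sum 11); values are 1, 2, 4
result = stream_row([[0.0, 1.0], [3.0]], [[1.0, 2.0], [4.0]])
direct = (1 * 1.0 + 2 * 2.0 + 8 * 4.0) / 11
```

After the first block the state is m_i = 1, l_i = 1.5, acc = 2.5; the second block raises the maximum to 3, so alpha = 0.25 and the old sums are scaled down before the new block is added.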

https://triton-lang.org/main/python-api/generated/triton.language.make_block_ptr.html

backward pass

Reference

https://medium.com/@e0928021388/%E7%AA%81%E7%A0%B4-transformers-%E7%9A%84%E9%80%9F%E5%BA%A6%E7%93%B6%E9%A0%B8-flash-attention-%E4%BB%8B%E7%B4%B9-28c1bc667fd9

Community discussion

How close could triton get to FlashAttention v3 performance? https://tridao.me/blog/2024/flash3

guesstimate is maybe within 20-30% with appropriately tuned tile size. There are some Triton numbers in this blog post that do not make sense, such as causal being 30% slower than non-causal for FP8 at 16k context