FlashAttention 123 (本文編輯中

Slided Link

基礎知識

Standard Attention Implementation


Chain rule

Backward 時的順序:

Flash Attention1

問題定義

  • 需要 intermediate activations 來幫助我們在 backward 的時候透過 chain rule 計算 gradients
  • 由於 SRAM 本身不夠大,而 softmax 這種需要計算 sum 的 operation,需要整個 row 的 element 都到齊後才可以計算,使得我們沒有辦法 apply 一些 divide and conquery 的 algorithm ,更使得我們沒有辦法把所有運算一口氣在 SRAM 當中計算完

手段1 Tiling (online-softmax)

  • 原本的做法:

    • 直接計算 softmax,需要等每一個 row 計算完後才能得到分母的 exponential summation。
  • 改進的做法:

    • 將 S row vector 和 V col vector 切成多個 sub blocks,各自計算出各自的 value
    • softmax

手段2 Recomputation

不儲存 intermediate activations 而是在有需要的時候再重新計算

  • 原本的做法:
    • forward 時將 S 和 P 這兩個 N*N 的 matrix 在 SRAM 計算完後存入 HBM
    • backward 時再將兩個 Matrix 從 HBM load 到 SRAM
  • 改進的做法:
    • backward 有需要時從 SRAM 重新計算一次
  • 相比於儲存 intermediate 的 I/O,直接拿在 cache 中的資料重算更快

實驗設計

結果

Flash Attention 2

問題定義

Flash Attention 還有一些常見的操作並沒有優化到,像是 GEMM。在 A100實驗中 forward pass 的時候只達到了理論最大值的 30~50%,backward pass 只有 25-35%。

has suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.

手段

better parallelism and work partitioning

調整 online-softmax實作

  • Forward Pass

    • scale the final output at end of loop
    • store the logsumexp
  • Backward Pass

    • use the row-wise logsumexp 𝐿 instead of both the row-wise max and row-wise sum of exponentials in the softmax.

      Parallelism 平行化運算


      會有三種維度可以進行平行運算

  • sequence length dimension

  • batch dimension

  • number of heads dimension

  • forward pass

    • each worker takes care of a block of rows of the attention matrix
  • backword pass
    *each worker takes care of a block of columns of the attention matrix

Wrap

use 4 or 8 warps per thread block

  • Forward Pass

  • Backward Pass

實驗設計

結果

Flash Attention 3

問題定義

在 H100 的架構下,uitilization 不佳

手段

  • Producer-Consumer asynchrony
  • Hiding softmax under asynchronous block-wise GEMMs
  • Hardware-accelerated low-precision GEMM

參考資料

https://hao.cnyes.com/post/97880

關於

AI Computing / 武術 / 登山 / IT / - 貪多而正努力咀嚼的人生小吃貨