這個練習的目的有二:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
bandwidth-bound operation
在有些時候也稱作 memory-bound
目前考慮單機器中
AI 運算的情境下通常有三種問題影響加速:
- computing bound – 負責算的速度不夠
- bandwith bound – 主要在等計算前後搬移資料的時間
- memory bound – GPU memory 不夠用
為什麼要 fuse/redcution operators?
假設原本有兩個運算 A 跟 B 要做
載入 input of A -> A 運算 -> 儲存output of A
載入 input of B -> B 運算 -> 儲存output of B
如果 output of A 跟 input of B 是同一個,那直接合在一起做的話就可以減少 兩次等待 memory 的操作了
Softmax
將一個 vector 的值轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。
而 softmax 由於運算時間佔比很小,會是個 memory-bound 的 operation
而將一個 matrix 做 softmax 的時候會要指定是要對row做還是對 column 來做
Fused Softmax
原始的 pytorch softmax 處理一個 MxN 矩陣的時候
會需要讀取 5MN+2M ,寫入 3MN+2M
在這個練習中目標就是只讀取一次 矩陣 MN 並寫入一次 MN
加速的基礎邏輯
在 GPU 的運算架構之下,概念上如果可以降低各個運算 core 對於資料的共享程度,就可以盡量讓那個 core 的運算在 L2 Cache 就完成,這樣就可以減少 Dram 的 IO,從而獲得加速
https://youtu.be/YhPbVSsUkhs?t=17
而 tl.load 的操作就是操作在 L2 Cache 中 (?)
x = tl.load(x_ptr + offsets, mask=mask)
三個比較基準的定義
假設 input 是一個 M*N matrix
-
Torch(jit)
- 理應是最慢的,因為他是在這個練習中使用 torch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
@torch.jit.script def naive_softmax(x):
x_max = x.max(dim=1)[0] # read MN elements ; write M elements
z = x – x_max[:, None] # read MN + M elements ; write MN elements
numerator = torch.exp(z) # read MN elements ; write MN elements
denominator = numerator.sum(dim=1) # read MN elements ; write M elements
ret = numerator / denominator[:, None] # read MN + M elements ; write MN elements
return ret這裡每一行的運算都可以視作一個 kernel 因為是用 torch script 寫出來的運作 所以每一次的運算都是從 DRAM 讀出來、運算完寫回 DRAM read 5MN+2M elements write 3MN+2M elements
- 理應是最慢的,因為他是在這個練習中使用 torch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
-
Torch(Native)
- torch native 是指使用使用 torch 提供的 torch.softmax()
-
Triton
- 這個練習中使用 triton寫的 softmax(x)
所以使用 triton.jit compile
將資料讀入後全部運算完才存回 DRAM@triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): row_idx = tl.program_id(0) row_start_ptr = input_ptr + row_idx * input_row_stride col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator # Write back output to DRAM output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
mv MN from dram to GPU core
write MN back to Dram
- 這個練習中使用 triton寫的 softmax(x)
設定 index
softmax_kernel[(n_rows, )]
對應取 program_id 取 axis=0
row_idx = tl.program_id(0)
用 row 來當作切分的依據
input_row_stride 就會是 1 row 有幾個 element https://pytorch.org/docs/stable/generated/torch.Tensor.stride.html
row_start_ptr = input_ptr + row_idx * input_row_stride
取得正要運算的 memory 位置
- tl.arange()
col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets
以下運算皆為整個 row 進行運算
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
num_warps
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
Reference
- https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
- https://clay-atlas.com/blog/2024/01/29/openai-triton-note-2-fused-softmax/
- https://medium.com/@e0928021388/%E7%AA%81%E7%A0%B4-transformers-%E7%9A%84%E9%80%9F%E5%BA%A6%E7%93%B6%E9%A0%B8-flash-attention-%E4%BB%8B%E7%B4%B9-28c1bc667fd9