Triton Tutorial Practice: 02 Fused Softmax
Softmax
將多分類的輸出轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。
而 softmax 由於運算時間佔比很小,會是個 memory-bound 的 operation
這個練習目的就是減少 memory 存取的方式來提供加速
Fused Softmax
加速的基礎邏輯
在 GPU 的運算架構之下,概念上如果可以降低各個運算 core 對於資料的共享程度,就可以盡量讓那個 core 的運算在 L2 Cache 就完成,這樣就可以減少 Dram 的 IO,從而獲得加速
https://youtu.be/YhPbVSsUkhs?t=17
而 tl.load 的操作就是操作在 L2 Cache 中 (?)
x = tl.load(x_ptr + offsets, mask=mask)
三個比較基準的定義
假設 input 是一個 M*N matrix
-
Torch(jit)
- torch jit 理應是最慢的,因為他是在這個練習中使用 ttorch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
@torch.jit.script def naive_softmax(x):
x_max = x.max(dim=1)[0] # read MN elements ; write M elements
z = x – x_max[:, None] # read MN + M elements ; write MN elements
numerator = torch.exp(z) # read MN elements ; write MN elements
denominator = numerator.sum(dim=1) # read MN elements ; write M elements
ret = numerator / denominator[:, None] # read MN + M elements ; write MN elements
return retread 5MN+2M elements
- torch jit 理應是最慢的,因為他是在這個練習中使用 ttorch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
write 3MN+2M elements
-
Torch(Native)
- torch native 是指使用使用 torch 提供的 torch.softmax()
-
Triton
- 這個練習中使用 triton 的 softmax(x)
@triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): row_idx = tl.program_id(0) row_start_ptr = input_ptr + row_idx * input_row_stride col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator # Write back output to DRAM output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
mv MN from dram to GPU core
write MN back to Dram
- 這個練習中使用 triton 的 softmax(x)
設定 index
softmax_kernel[(n_rows, )]
對應取 program_id 取 axis=0
row_idx = tl.program_id(0)
用 row 來當作切分的依據
input_row_stride 就會是 1 row 有幾個 element https://pytorch.org/docs/stable/generated/torch.Tensor.stride.html
row_start_ptr = input_ptr + row_idx * input_row_stride
取得正要運算的 memory 位置
- tl.arange()
col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets
以下運算皆為整個 row 進行運算
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
num_warps
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
Reference
- https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
- https://clay-atlas.com/blog/2024/01/29/openai-triton-note-2-fused-softmax/
- https://medium.com/@e0928021388/%E7%AA%81%E7%A0%B4-transformers-%E7%9A%84%E9%80%9F%E5%BA%A6%E7%93%B6%E9%A0%B8-flash-attention-%E4%BB%8B%E7%B4%B9-28c1bc667fd9