Triton Tutorial Practice: 02 Fused Softmax

Softmax

將多分類的輸出轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。

而 softmax 由於運算時間佔比很小,會是個 memory-bound 的 operation

這個練習目的就是減少 memory 存取的方式來提供加速

Fused Softmax

加速的基礎邏輯

在 GPU 的運算架構之下,概念上如果可以降低各個運算 core 對於資料的共享程度,就可以盡量讓那個 core 的運算在 L2 Cache 就完成,這樣就可以減少 Dram 的 IO,從而獲得加速

file
https://youtu.be/YhPbVSsUkhs?t=17
而 tl.load 的操作就是操作在 L2 Cache 中 (?)

x = tl.load(x_ptr + offsets, mask=mask)

三個比較基準的定義

假設 input 是一個 M*N matrix

  • Torch(jit)

    • torch jit 理應是最慢的,因為他是在這個練習中使用 ttorch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
      
      @torch.jit.script
      def naive_softmax(x):

    x_max = x.max(dim=1)[0] # read MN elements ; write M elements
    z = x – x_max[:, None] # read MN + M elements ; write MN elements
    numerator = torch.exp(z) # read MN elements ; write MN elements
    denominator = numerator.sum(dim=1) # read MN elements ; write M elements
    ret = numerator / denominator[:, None] # read MN + M elements ; write MN elements
    return ret

    
    read 5MN+2M elements 

write 3MN+2M elements

  • Torch(Native)

    • torch native 是指使用使用 torch 提供的 torch.softmax()
  • Triton

    • 這個練習中使用 triton 的 softmax(x)
      @triton.jit
      def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
      row_idx = tl.program_id(0)
      row_start_ptr = input_ptr + row_idx * input_row_stride
      col_offsets = tl.arange(0, BLOCK_SIZE)
      input_ptrs = row_start_ptr + col_offsets
      row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
      row_minus_max = row - tl.max(row, axis=0)
      numerator = tl.exp(row_minus_max)
      denominator = tl.sum(numerator, axis=0)
      softmax_output = numerator / denominator
      # Write back output to DRAM
      output_row_start_ptr = output_ptr + row_idx * output_row_stride
      output_ptrs = output_row_start_ptr + col_offsets
      tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

      mv MN from dram to GPU core
      write M
      N back to Dram

設定 index

  softmax_kernel[(n_rows, )]

對應取 program_id 取 axis=0

 row_idx = tl.program_id(0)

用 row 來當作切分的依據

input_row_stride 就會是 1 row 有幾個 element https://pytorch.org/docs/stable/generated/torch.Tensor.stride.html

row_start_ptr = input_ptr + row_idx * input_row_stride

取得正要運算的 memory 位置

  • tl.arange()
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

以下運算皆為整個 row 進行運算

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

num_warps

    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))

tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

Reference

關於

AI Computing / 武術 / 登山 / IT / - 貪多而正努力咀嚼的人生小吃貨