Triton Tutorial Practice: 02 Fused Softmax

這個練習的目的有二:

  • The benefits of kernel fusion for bandwidth-bound operations.
  • Reduction operators in Triton.

bandwidth-bound operation

在有些時候也稱作 memory-bound

目前考慮單機器中
AI 運算的情境下通常有三種問題影響加速:

  1. computing bound – 負責算的速度不夠
  2. bandwith bound – 主要在等計算前後搬移資料的時間
  3. memory bound – GPU memory 不夠用

為什麼要 fuse/redcution operators?

假設原本有兩個運算 A 跟 B 要做
載入 input of A -> A 運算 -> 儲存output of A
載入 input of B -> B 運算 -> 儲存output of B
如果 output of A 跟 input of B 是同一個,那直接合在一起做的話就可以減少 兩次等待 memory 的操作了

Softmax

將一個 vector 的值轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。

而 softmax 由於運算時間佔比很小,會是個 memory-bound 的 operation

而將一個 matrix 做 softmax 的時候會要指定是要對row做還是對 column 來做

Fused Softmax

原始的 pytorch softmax 處理一個 MxN 矩陣的時候
會需要讀取 5MN+2M ,寫入 3MN+2M

在這個練習中目標就是只讀取一次 矩陣 MN 並寫入一次 MN

加速的基礎邏輯

在 GPU 的運算架構之下,概念上如果可以降低各個運算 core 對於資料的共享程度,就可以盡量讓那個 core 的運算在 L2 Cache 就完成,這樣就可以減少 Dram 的 IO,從而獲得加速

file
https://youtu.be/YhPbVSsUkhs?t=17
而 tl.load 的操作就是操作在 L2 Cache 中 (?)

x = tl.load(x_ptr + offsets, mask=mask)

三個比較基準的定義

假設 input 是一個 M*N matrix

  • Torch(jit)

    • 理應是最慢的,因為他是在這個練習中使用 torch script 手刻沒有經過任何加速的方式的 naive_softmax(x)
      
      @torch.jit.script
      def naive_softmax(x):

    x_max = x.max(dim=1)[0] # read MN elements ; write M elements
    z = x – x_max[:, None] # read MN + M elements ; write MN elements
    numerator = torch.exp(z) # read MN elements ; write MN elements
    denominator = numerator.sum(dim=1) # read MN elements ; write M elements
    ret = numerator / denominator[:, None] # read MN + M elements ; write MN elements
    return ret

    
    這裡每一行的運算都可以視作一個 kernel
    因為是用 torch script 寫出來的運作
    所以每一次的運算都是從 DRAM 讀出來、運算完寫回 DRAM
    
    read 5MN+2M elements
    write 3MN+2M elements
  • Torch(Native)

    • torch native 是指使用使用 torch 提供的 torch.softmax()
  • Triton

    • 這個練習中使用 triton寫的 softmax(x)
      所以使用 triton.jit compile
      將資料讀入後全部運算完才存回 DRAM

      @triton.jit
      def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
      row_idx = tl.program_id(0)
      row_start_ptr = input_ptr + row_idx * input_row_stride
      col_offsets = tl.arange(0, BLOCK_SIZE)
      input_ptrs = row_start_ptr + col_offsets
      row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
      row_minus_max = row - tl.max(row, axis=0)
      numerator = tl.exp(row_minus_max)
      denominator = tl.sum(numerator, axis=0)
      softmax_output = numerator / denominator
      # Write back output to DRAM
      output_row_start_ptr = output_ptr + row_idx * output_row_stride
      output_ptrs = output_row_start_ptr + col_offsets
      tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

      mv MN from dram to GPU core
      write M
      N back to Dram

設定 index

  softmax_kernel[(n_rows, )]

對應取 program_id 取 axis=0

 row_idx = tl.program_id(0)

用 row 來當作切分的依據

input_row_stride 就會是 1 row 有幾個 element https://pytorch.org/docs/stable/generated/torch.Tensor.stride.html

row_start_ptr = input_ptr + row_idx * input_row_stride

取得正要運算的 memory 位置

  • tl.arange()
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

以下運算皆為整個 row 進行運算

    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

num_warps

    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))

tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

Reference