Triton Tutorial Practice: 03 Matrix Multiplication

學習目標

  • Block-level matrix multiplications.
  • Multi-dimensional pointer arithmetic.
  • Program re-ordering for improved L2 cache hit rate.
  • Automatic performance tuning.

matrix multiplications

用二維矩陣相乘為例,如果我們寫一個一般的 CPU 程式,
就會分別以兩個 loop 對兩個維度進行 loop,每一次執行一個元素,所以也每一次也會做一次記憶體存、取

而在 GPU 上執行矩陣相乘的時候,該做的並不是「執行loop」,而是針對記憶體的位置進行標示
一次載入一大塊記憶體,而語法的目的是為了讓不同的 thread 可以取得他要運算所需的記憶體

for m in range(0, M, BLOCK_SIZE_M):
  # Do in parallel
  for n in range(0, N, BLOCK_SIZE_N):
    acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
    for k in range(0, K, BLOCK_SIZE_K):
      a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
      b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
      acc += dot(a, b)
    C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc

pointer arithmetic

&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] =  a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] =  b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a_ptrs += BLOCK_SIZE_K * stride_ak;
b_ptrs += BLOCK_SIZE_K * stride_bk;

L2 Cache Optimizations

一個 program instance 會執行一個 Block 的運算
所以為了最佳化,會將有相關連的 block group 在一起在進行運算

以下程式並不符合,(為什麼?)

pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

參照下面的圖片,沒有特別優化的載入方式會像是 Row-major ordering 的部分
file
如果要盡量地讓「當下的記憶體都存在 Cache 中的話,要減少「停滯未使用的資料」
一個方式是用 grouped ordering 的方式

圖片下半部就是算出 group size 是3後,重新規劃突入的 row 跟 column 來做運算的方法

# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m

Automatic performance tuning

@triton.autotune 的用法是
提供一個 triton.Config 的 list 來定義 meta-parameter 跟 compilation options
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8) 就是一個 triton.Config
BLOCK_SIZE_M 就是 meta-parameter
num_stages 跟 num_warps compilation options

  • 有哪些 meta-parameter

  • 有哪些 compilation options

    • An auto-tuning key whose change in values will trigger evaluation of all the provided configs
      key=['M', 'N', 'K']

code implement

matmul_kernel

位置計算

A * B = C
A: MxK B:KxN C:MxN

pid: 標注自己是第幾個 program (thread)
所以針對每一個 pid 我們要先讓他知道自己要計算的位置在哪邊

num_pid_m: m這個 size 裡總共會有幾個 pid
num_pid_in_group: grouping 時一個 group 裡會有幾個 pid

group_id: 這個 pid 在第幾個 group

first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)

pid_m: 這個 pid 對應的 在 A 之中的 m 座標 (第幾個 block)
pid_n: 這個 pid 對應的 在 B之中的 n 座標 (第幾個 block)

stride & pointer 計算

offs_* : 這個 pid 在 block中的第幾個元素

a_ptrs: 實際展開在記憶體中的位置: A的位置+m座標位移+k座標位移

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
offs_k[None, :] < K - k * BLOCK_SIZE_K

mask

mask=offs_k[None, :] < K - k * BLOCK_SIZE_K
(offs_cm[:, None] < M) & (offs_cn[None, :] < N)

Reference