Triton Tutorial Practice: 01 Vector Addition

Triton vs Triton Language

Triton 的所有操作,可以在 / python/triton/init.py 下查看,共了 19 個操作

Triton Language 共 95 個操作

差異就在於 triton language 是實際要經過 compiler 編譯的區塊,而 triton 是在 compiler 外用來介接

Triton Language 基礎

如何標註 compiler 開始與結束

@triton.jit 裝飾子 的意思是表示後面的語法是需要經過 triton compiler 去編譯的

  1. 使用 @triton.jit 標註編譯區間
  2. 實作 JIT function
    A JIT function is launched with: fn[grid](*args, **kwargs).
    https://github.com/triton-lang/triton/blob/main/python/triton/runtime/jit.py
    這裡的 grid 也就是會需要在外面提供 grid
    並在 compile 區間 使用 tl.program_id 去取看是哪一維
  3. 修改 op 或是其他地方 call kernel function

triton.language

如何查詢定義
https://triton-lang.org/main/python-api/triton.language.html

program

  • tl.program_id()
    pid = tl.program_id(axis=0)

    可以從三個維度去取 program_id
    axis (int) – The axis of the 3D launch grid. Must be 0, 1 or 2.

memory

add function

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

n_elements 是 elements的數量
grid 是 一個lamda 函式
用來定義了 add_kernel() 所需要傳入的 grid function
block size 去除以元素數量去看一個 block 中有幾個元素

grid(meta) 實際上就是回傳 (num_blocks,)
add_kernel[(num_blocks,)](x, y, output, n_elements, BLOCK_SIZE=1024)

add_kernel function

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

add_kernel 會被 compiler 編譯後平行的運算
所以會需要使用 tl.program_id 來看「現在執行的是哪一個運算」

問題

跟 pytorch 原生的實際差異

Triton 在處理的事情是以 Block 為基礎分割運算,
使用 program_id 來標註對應的 index,這樣編譯出來的運算就可以用 block 為單位來分割並平行運算來加速

BLOCK_SIZE 是什麼

block size 是分割 memory 的單位,意指 memory 存取時的基礎單位,以每個 block 為基礎來分割運算

block size 指定成 1024 的意思

其實這邊應該是省略了優化 block size 的部分
不同的 element size 應該會有不同的最佳 block size

會需要使用 autotune 之類的方式在 compile 前指定 block size

@triton.autotune(configs = [
    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])

Test & Benchmark 框架

@triton.testing.perf_report()
https://triton-lang.org/main/python-api/generated/triton.testing.perf_report.html#triton.testing.perf_report

Reference

https://clay-atlas.com/blog/2024/01/28/openai-triton-vector-addition/
https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
https://isamu-website.medium.com/understanding-triton-tutorials-part-2-f6839ce50ae7
http://giantpandacv.com/project/%E9%83%A8%E7%BD%B2%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%BC%96%E8%AF%91%E5%99%A8/OpenAI%20Triton%20MLIR%20%E7%AC%AC%E4%B8%80%E7%AB%A0%20Triton%20DSL/

關於

AI Computing / 武術 / 登山 / IT / - 貪多而正努力咀嚼的人生小吃貨