Triton Tutorial Practice: 01 Vector Addition
Triton vs Triton Language
Triton 的所有操作,可以在 / python/triton/init.py 下查看,共了 19 個操作
Triton Language 共 95 個操作
差異就在於 triton language 是實際要經過 compiler 編譯的區塊,而 triton 是在 compiler 外用來介接
Triton Language 基礎
如何標註 compiler 開始與結束
@triton.jit
裝飾子 的意思是表示後面的語法是需要經過 triton compiler 去編譯的
- 使用 @triton.jit 標註編譯區間
- 實作 JIT function
A JIT function is launched with: fn[grid](*args, **kwargs).
https://github.com/triton-lang/triton/blob/main/python/triton/runtime/jit.py
這裡的 grid 也就是會需要在外面提供 grid
並在 compile 區間 使用 tl.program_id 去取看是哪一維 - 修改 op 或是其他地方 call kernel function
triton.language
如何查詢定義
https://triton-lang.org/main/python-api/triton.language.html
program
- tl.program_id()
pid = tl.program_id(axis=0)
可以從三個維度去取 program_id
axis (int) – The axis of the 3D launch grid. Must be 0, 1 or 2.
memory
-
tl.arange()
-
triton.cdiv()
向上取整數的除法,用來計算一個 block 要處理多少元素
https://triton-lang.org/main/python-api/generated/triton.language.cdiv.html#triton.language.cdiv -
Q:為啥不是 tl.cdiv()?
A: 因為這段並不是在 JIT function 之中
add function
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
n_elements 是 elements的數量
grid 是 一個lamda 函式
用來定義了 add_kernel() 所需要傳入的 grid function
block size 去除以元素數量去看一個 block 中有幾個元素
grid(meta) 實際上就是回傳 (num_blocks,)
add_kernel[(num_blocks,)](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel function
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
add_kernel 會被 compiler 編譯後平行的運算
所以會需要使用 tl.program_id 來看「現在執行的是哪一個運算」
問題
跟 pytorch 原生的實際差異
Triton 在處理的事情是以 Block 為基礎分割運算,
使用 program_id 來標註對應的 index,這樣編譯出來的運算就可以用 block 為單位來分割並平行運算來加速
BLOCK_SIZE 是什麼
block size 是分割 memory 的單位,意指 memory 存取時的基礎單位,以每個 block 為基礎來分割運算
block size 指定成 1024 的意思
其實這邊應該是省略了優化 block size 的部分
不同的 element size 應該會有不同的最佳 block size
會需要使用 autotune 之類的方式在 compile 前指定 block size
@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
Test & Benchmark 框架
@triton.testing.perf_report()
https://triton-lang.org/main/python-api/generated/triton.testing.perf_report.html#triton.testing.perf_report
Reference
https://clay-atlas.com/blog/2024/01/28/openai-triton-vector-addition/
https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
https://isamu-website.medium.com/understanding-triton-tutorials-part-2-f6839ce50ae7
http://giantpandacv.com/project/%E9%83%A8%E7%BD%B2%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%BC%96%E8%AF%91%E5%99%A8/OpenAI%20Triton%20MLIR%20%E7%AC%AC%E4%B8%80%E7%AB%A0%20Triton%20DSL/