add function
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
n_elements 是 elements的數量
grid 是 一個lamda 函式
用來定義了 add_kernel() 所需要傳入的 grid function
block size 去除以元素數量去看一個 block 中有幾個元素
grid(meta) 實際上就是回傳 (num_blocks,)
add_kernel[(num_blocks,)](x, y, output, n_elements, BLOCK_SIZE=1024)
add_kernel function
add_kernel 可以算是最簡單的 kernel 格式
本質的程式只有 output= x+y 這行
前方跟後方的部分是存取 GPU memory 的位置計算
主要由幾個要素組成
- 從哪裡開始 bloc_start
- 一格多大
- 怎麼樣跳下一格 tl.program_id(axis=0)
pid = tl.program_id(offest)
BLOCK_SIZE
所以 tl.arange 跟 tl.program_id 的 0 指的都是
這一算法的跳躍方式
以下影片中的 add_kernel 更加簡化,甚至沒有用到 mask
https://www.youtube.com/watch?v=etlFyqSsmL0
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
add_kernel 會被 compiler 編譯後平行的運算
所以會需要使用 tl.program_id 來看「現在執行的是哪一個運算」
問題
跟 pytorch 原生的實際差異
Triton 在處理的事情是以 Block 為基礎分割運算,
使用 program_id 來標註對應的 index,這樣編譯出來的運算就可以用 block 為單位來分割並平行運算來加速
BLOCK_SIZE 是什麼
block size 是分割 memory 的單位,意指 memory 存取時的基礎單位,以每個 block 為基礎來分割運算
block size 指定成 1024 的意思
其實這邊應該是省略了優化 block size 的部分
不同的 element size 應該會有不同的最佳 block size
會需要使用 autotune 之類的方式在 compile 前指定 block size
@triton.autotune(configs = [
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
Test & Benchmark 框架
@triton.testing.perf_report()
https://triton-lang.org/main/python-api/generated/triton.testing.perf_report.html#triton.testing.perf_report
Reference
https://clay-atlas.com/blog/2024/01/28/openai-triton-vector-addition/
https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
https://isamu-website.medium.com/understanding-triton-tutorials-part-2-f6839ce50ae7
http://giantpandacv.com/project/%E9%83%A8%E7%BD%B2%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%BC%96%E8%AF%91%E5%99%A8/OpenAI%20Triton%20MLIR%20%E7%AC%AC%E4%B8%80%E7%AB%A0%20Triton%20DSL/