Triton 以其低门槛开发和抽象的硬件细节处理,成为开发者的优选。
对于渴望参与 Triton 开源社区建设的开发者来说,优化 Triton 算子是一条理想的路径。优化后的 Triton 算子性能有望匹敌甚至超越 PyTorch 的原生实现。
正如古人云:“工欲善其事,必先利其器”,本文将介绍 Triton 算子优化的利器——自动调优(autotune)。
Triton 算子的实现代码
一个简单的 Triton 算子的实现代码通常分为两个部分。
-
计算准备阶段:涉及输入张量的预处理(如转换为连续布局张量),计算输出张量的形状并分配内存,以及设置运行参数(如 grid 和 BLOCK_SIZE)。
-
核函数调用:在 GPU 上实现计算逻辑。
以下是一个对三维张量 inp(形状为 [M,N,K])计算在第 1 维(N 所在维)上最大值下标的例子,即 argmax 算子。下文将在本例的基础上讲述 autotune 的用法。
def argmax(inp):
# 第一部分
dim = 1 # 本例是 argmax 的一个特化实现,仅能处理 dim=1 的情形
N = shape[dim]
M = prod(shape[:dim])
K = inp.numel() // M // N
inp = inp.contiguous() # 将输入转换为连续布局张量,本例中可以简化核函数实现
shape = list(inp.shape)
shape[dim] = 1 # 第 dim 维取最大值下标,因此该维度上输出 shape 归一
out_index = torch.empty(shape, dtype=torch.int64, device=inp.device) # 分配输出
grid = lambda meta: ( # 本例 grid 使用二维,通过表达式指定
triton.cdiv(M, meta["BLOCK_M"]),
K,
)
# 第二部分
with torch.cuda.device(inp.device):
argmax_kernel[grid](
inp,