一 前言
lz最开始找了好多教程,发现单个教程总是会忽略一下细节,让人少许摸不到头脑,当然也可能是对cuda编程不熟悉,在此查阅多方资料,给出自己的理解,不对之处,还望读者批评指正。
二 内容分析
2.1 函数定义
首先来看一下triton中的函数定义
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
- 这里的
triton.jit
基本上是告诉编译器这是一个 triton kernel 函数, 你想要定义一个 kernel, 前面就要加这么一个 Decorator. a_ptr
b_ptr
c_ptr
这三个都是 pointer. 熟悉 C 语言的肯定知道这是什么意思. 不过如果不熟悉也没关系, 概念也很简单. 假如有一个 python array, 这个 array 里面每个元素都有一个地址, 前后两个元素的地址一般也是连着的, 所以, 我们只要知道这个 array 的第一个元素的地址, 就可以顺势推出所有元素的地址. 这里的a_ptr
b_ptr
c_ptr
就是各自 tensor 的第一个元素的地址.- Triton 里面一般都是通过 pointer 这种方式来表示变量. 为什么? 其实这跟 Triton 要解决的问题有关系. Triton 本质上是要追求效率的, 而效率中很重要的一环就是 memory IO, 也就是 data load. 所以 Triton 希望用户能够 be aware of the cost of data loading, 来写出来更加高效的 kernel. 所以, 在 Triton 中, 每个数字 scalar 都需要通过他的 pointer, 手动 load 到 memory 中来. 这个在后面的代码中可以清楚的看到
- 输入中还有各种
stride
, 这个 stride 是什么意思呢. 我们可以直接从 PyTorch 来看.
a = torch.rand([3,6])
a.stride()
# (6, 1)
这里的第一个维度的 stride 是 6, 因为从 a[m, k]
的地址 到 a[m+1, k]
的地址, 中间差了 6 个元素 (具体差了多少 byte 取决于数据类型),
第二个维度的 stride 是 1, 因为从 a[m, k]
的地址 到 a[m, k+1]
的地址, 中间差了 1 个元素.
换句话说, stride 的作用就是为了更加方便的找到每个元素的 pointer (地址)
5. 剩下还有很多 tl.constexpr
这个大家可以理解成这个 kernel 的 超参数 (hyper-parameters), 这个超参数后面可以由 Triton compiler 来进行搜索不同的值. 为什么搜索这些超参数很重要呢? 其实很简单, 因为我们一个 kernel 的输入形状会变, 最终执行这个 kernel 的硬件也有不同规格 (比如不同型号的 GPU 有不同大小的 shared memory size), 所以找到一组合适的超参数是对于优化效率也是很重要的.
++++++++++++++++++++++++++++++++++++++++++++&