triton教程--matrix multiplication逐行解释

一 前言

lz最开始找了好多教程,发现单个教程总是会忽略一下细节,让人少许摸不到头脑,当然也可能是对cuda编程不熟悉,在此查阅多方资料,给出自己的理解,不对之处,还望读者批评指正。

二 内容分析

2.1 函数定义

首先来看一下triton中的函数定义



@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

  1. 这里的 triton.jit 基本上是告诉编译器这是一个 triton kernel 函数, 你想要定义一个 kernel, 前面就要加这么一个 Decorator.
  2. a_ptr b_ptr c_ptr 这三个都是 pointer. 熟悉 C 语言的肯定知道这是什么意思. 不过如果不熟悉也没关系, 概念也很简单. 假如有一个 python array, 这个 array 里面每个元素都有一个地址, 前后两个元素的地址一般也是连着的, 所以, 我们只要知道这个 array 的第一个元素的地址, 就可以顺势推出所有元素的地址. 这里的 a_ptr b_ptr c_ptr 就是各自 tensor 的第一个元素的地址.
  3. Triton 里面一般都是通过 pointer 这种方式来表示变量. 为什么? 其实这跟 Triton 要解决的问题有关系. Triton 本质上是要追求效率的, 而效率中很重要的一环就是 memory IO, 也就是 data load. 所以 Triton 希望用户能够 be aware of the cost of data loading, 来写出来更加高效的 kernel. 所以, 在 Triton 中, 每个数字 scalar 都需要通过他的 pointer, 手动 load 到 memory 中来. 这个在后面的代码中可以清楚的看到
  4. 输入中还有各种 stride , 这个 stride 是什么意思呢. 我们可以直接从 PyTorch 来看.
a = torch.rand([3,6])
a.stride() 
# (6, 1) 

这里的第一个维度的 stride 是 6, 因为从 a[m, k] 的地址 到 a[m+1, k] 的地址, 中间差了 6 个元素 (具体差了多少 byte 取决于数据类型),

第二个维度的 stride 是 1, 因为从 a[m, k] 的地址 到 a[m, k+1] 的地址, 中间差了 1 个元素.

换句话说, stride 的作用就是为了更加方便的找到每个元素的 pointer (地址)

5. 剩下还有很多 tl.constexpr 这个大家可以理解成这个 kernel 的 超参数 (hyper-parameters), 这个超参数后面可以由 Triton compiler 来进行搜索不同的值. 为什么搜索这些超参数很重要呢? 其实很简单, 因为我们一个 kernel 的输入形状会变, 最终执行这个 kernel 的硬件也有不同规格 (比如不同型号的 GPU 有不同大小的 shared memory size), 所以找到一组合适的超参数是对于优化效率也是很重要的.

+++++++++++++++++++++++++++++++++++++++++

  • 29
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

youzjuer

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值