Triton Matrix Multiplication

For the matrix product $C = A \times B$, the shapes are $A: M \times K$, $B: K \times N$ and $C: M \times N$.

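[Figure 1: blocked matrix multiplication; the red boxes in A and B mark the data needed for one block of C]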

Consider the simple case where M is divisible by BLOCK_SIZE_M = 3, K by BLOCK_SIZE_K = 3, and N by BLOCK_SIZE_N = 3. The blocked algorithm looks like this:

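```python
import numpy as np

def blocked_matmul(A, B, BLOCK_SIZE_M=3, BLOCK_SIZE_N=3, BLOCK_SIZE_K=3):
    M, K = A.shape
    N = B.shape[1]
    C = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, BLOCK_SIZE_M):          # block rows of C
        for n in range(0, N, BLOCK_SIZE_N):      # block columns of C
            acc = np.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=np.float32)
            for k in range(0, K, BLOCK_SIZE_K):  # reduce along K
                a = A[m:m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]  # the part Triton rewrites
                b = B[k:k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N]
                acc += a @ b
            C[m:m+BLOCK_SIZE_M, n:n+BLOCK_SIZE_N] = acc
    return C
```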

Triton can replace the code above almost piece by piece.

Selecting block (m, n) of C

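The two outer for loops enumerate the blocks of C. The blocks are independent of each other, so they can all be computed in parallel; there are (M // BLOCK_SIZE_M) × (N // BLOCK_SIZE_N) of them (in general tl.cdiv(M, BLOCK_SIZE_M) × tl.cdiv(N, BLOCK_SIZE_N), which also counts partial blocks at the edges). The official code identifies a specific block (m, n) with a single one-dimensional program ID, pid, which ranges from 0 to num_pid_m × num_pid_n − 1.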

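```python
# The two outer loops over the blocks of C ...
for m in range(0, M, BLOCK_SIZE_M):
    for n in range(0, N, BLOCK_SIZE_N):
        ...  # compute block (m, n)

# ... are replaced in Triton by a 1-D grid of programs:
pid = tl.program_id(axis=0)           # which block of C this program computes
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)  # number of block rows of C
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)  # number of block columns; pid ranges over 0 .. num_pid_m * num_pid_n - 1
```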
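A quick way to convince yourself that a one-dimensional pid can stand in for the two loops is to emulate the launch grid on the CPU. The pure-Python sketch below assumes the simplest row-major mapping, pid_m = pid // num_pid_n and pid_n = pid % num_pid_n (the official kernel uses the grouped ordering described next instead); `cdiv` and `block_of_pid` are hypothetical helper names used only for this illustration.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same quantity tl.cdiv computes."""
    return (a + b - 1) // b


def block_of_pid(pid: int, num_pid_n: int) -> tuple:
    """Hypothetical row-major mapping from a 1-D pid to block (pid_m, pid_n)."""
    return pid // num_pid_n, pid % num_pid_n


M, N = 9, 6                          # illustrative sizes
BLOCK_SIZE_M, BLOCK_SIZE_N = 3, 3
num_pid_m = cdiv(M, BLOCK_SIZE_M)    # 3 block rows
num_pid_n = cdiv(N, BLOCK_SIZE_N)    # 2 block columns
grid = num_pid_m * num_pid_n         # 6 programs, pid = 0 .. 5

blocks = [block_of_pid(pid, num_pid_n) for pid in range(grid)]
for pid, (pid_m, pid_n) in enumerate(blocks):
    m, n = pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N
    print(f"pid={pid} -> block ({pid_m}, {pid_n}) = C[{m}:{m + BLOCK_SIZE_M}, {n}:{n + BLOCK_SIZE_N}]")
```

Every (pid_m, pid_n) pair appears exactly once, which is all the kernel needs: each program computes one block of C and no two programs write the same block. On the GPU the number of programs is set by the launch grid; in the official tutorial that grid is one-dimensional with triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N) programs.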

Loading the blocks of A and B that produce block (m, n)

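As shown in Figure 1, to compute block (m, n) we need the data in the red-boxed regions of A and B, so the kernel first works out which block row and block column the current pid corresponds to: (pid_m, pid_n) = (m, n). The official code for this is rather opaque. It partitions C into groups, each a horizontal strip GROUP_SIZE_M block rows tall that spans all N columns (that is, (GROUP_SIZE_M × BLOCK_SIZE_M) × N elements; only when GROUP_SIZE_M = 1 is a group a single BLOCK_SIZE_M × N strip), and inside a group it walks down a column of blocks before moving right. The purpose is to improve the L2 cache hit rate, because programs that run close together then reuse the same blocks of A and B. The grouping does not change which blocks are computed, so you can ignore it when following the matmul logic itself; the English-language write-ups do not explain it very clearly either.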

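```python
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)          # number of block rows of C
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)          # number of block columns of C
num_pid_in_group = GROUP_SIZE_M * num_pid_n   # programs per group (GROUP_SIZE_M block rows)
group_id = pid // num_pid_in_group            # which group this program belongs to
first_pid_m = group_id * GROUP_SIZE_M         # first block row of that group
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)        # the last group may be shorter
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)  # block row, varies fastest
pid_n = (pid % num_pid_in_group) // group_size_m                 # block column
```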
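The grouped mapping can be checked the same way. This plain-Python sketch copies the index arithmetic above verbatim and adds a hypothetical `distinct_loads` helper that counts how many distinct block rows of A and block columns of B the first few programs touch, as a rough proxy for memory traffic. The 9 × 9 block grid and GROUP_SIZE_M = 3 are illustrative values, not something the kernel prescribes.

```python
def grouped_pid_to_block(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    """Plain-Python copy of the grouped pid -> (pid_m, pid_n) mapping."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def distinct_loads(order, first_k):
    """Distinct A block rows plus distinct B block columns used by the first `first_k` programs."""
    head = order[:first_k]
    return len({m for m, _ in head}) + len({n for _, n in head})


num_pid_m, num_pid_n, GROUP_SIZE_M = 9, 9, 3   # a 9 x 9 grid of C blocks
grid = num_pid_m * num_pid_n
grouped = [grouped_pid_to_block(p, num_pid_m, num_pid_n, GROUP_SIZE_M) for p in range(grid)]
row_major = [(p // num_pid_n, p % num_pid_n) for p in range(grid)]

print("first 9 programs, grouped:", grouped[:9])
print("row-major distinct block rows + columns:", distinct_loads(row_major, 9))
print("grouped   distinct block rows + columns:", distinct_loads(grouped, 9))
```

Both orderings visit every block exactly once and differ only in the order. For the first 9 programs, row-major order touches 1 block row of A and 9 block columns of B (10 in total), while grouped order touches 3 of each (6 in total). If K is also split into 9 blocks, that is 90 versus 54 block loads; fewer distinct blocks means more of the loads can be served from L2.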

Computing addresses from (pid_m, pid_n)

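```python
# Row indices of the A block and column indices of the B block.
# "% M" / "% N" wrap indices past the matrix edge back into range so the loads never
# read out of bounds; the wrapped rows/columns are discarded later by the store mask.
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
# [:, None] and [None, :] broadcast the 1-D offsets into 2-D tiles of pointers
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # BLOCK_SIZE_M x BLOCK_SIZE_K
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)  # BLOCK_SIZE_K x BLOCK_SIZE_N
```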
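To check the pointer arithmetic without a GPU, the sketch below assumes A is a row-major M × K matrix stored as a flat Python list, so stride_am = K and stride_ak = 1, and rebuilds the a_ptrs offsets with nested list comprehensions in place of the [:, None] / [None, :] broadcasting. `block_offsets` is a hypothetical helper, and the sizes are illustrative, chosen so that the last block row runs past M.

```python
def block_offsets(pid_m, BLOCK_SIZE_M, BLOCK_SIZE_K, M, stride_am, stride_ak):
    """2-D list of flat offsets, mirroring
    offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak."""
    # % M wraps rows that fall past the end of A back into range, as in the kernel
    offs_am = [(pid_m * BLOCK_SIZE_M + i) % M for i in range(BLOCK_SIZE_M)]
    offs_k = list(range(BLOCK_SIZE_K))
    return [[r * stride_am + c * stride_ak for c in offs_k] for r in offs_am]


M, K = 5, 4
A = [[10 * r + c for c in range(K)] for r in range(M)]   # A[r][c] = 10r + c
flat_A = [x for row in A for x in row]                   # row-major storage
stride_am, stride_ak = K, 1

# Block row 1 with BLOCK_SIZE_M = 3 covers rows 3 and 4, then wraps row 5 to row 0
offs = block_offsets(1, 3, 2, M, stride_am, stride_ak)
tile = [[flat_A[o] for o in row] for row in offs]
print(offs)
print(tile)
```

The last row of the tile comes from wrapping row 5 back to row 0. That load is harmless because the address is valid, and the result for that row is never written: the store offsets offs_cm do not wrap, so the store mask offs_cm < M drops it.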

Computing and storing the result

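Indexing such as `offs_cm[:, None]` works exactly as it does on a `torch.Tensor`: `[:, None]` turns a 1-D vector into a column, `[None, :]` turns it into a row, and the two broadcast into a 2-D tile. Since Triton kernels are hard to debug, it is practical to prototype this index arithmetic in torch first.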
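Following the suggestion to prototype outside Triton, here is the C-side index arithmetic in NumPy, whose `[:, None]` / `[None, :]` indexing and broadcasting follow the same rules as torch.Tensor (swapping `np.arange` for `torch.arange` should behave the same). The sizes, strides and the choice of block (1, 1) are illustrative assumptions: C is taken to be a row-major 6 × 5 matrix.

```python
import numpy as np

BLOCK_SIZE_M, BLOCK_SIZE_N = 4, 3
M, N = 6, 5                     # illustrative sizes; C is M x N, row-major
stride_cm, stride_cn = N, 1
pid_m, pid_n = 1, 1

offs_cm = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)   # rows 4..7
offs_cn = pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)   # columns 3..5

# offs_cm[:, None] has shape (4, 1) and offs_cn[None, :] has shape (1, 3);
# adding them broadcasts to a (4, 3) tile of flat offsets into C
c_offsets = stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

print(c_offsets)
print(c_mask)
```

Only the top-left 2 × 2 corner of this tile lies inside C, so a store with this mask would write just those four elements.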

```python
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # Load the next block of A and B, generate a mask by checking the K dimension.
    # If it is out of bounds, set it to 0.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
    # We accumulate along the K dimension.
    accumulator = tl.dot(a, b, accumulator)
    # Advance the ptrs to the next K block.
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk
```
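When K is not a multiple of BLOCK_SIZE_K, the last iteration would read past the end of the K dimension, which is what the mask offs_k < K - k * BLOCK_SIZE_K guards against. The pure-Python sketch below emulates this loop for one output block: masked elements are replaced by 0.0, like `other=0.0`, so they add nothing to the dot product. `matmul_block_masked` is a hypothetical name, and the matrices hold small illustrative integers so the comparison with a naive product is exact.

```python
def matmul_block_masked(A, B, rows, cols, K, BLOCK_SIZE_K):
    """Accumulate one block of C = A @ B over K, BLOCK_SIZE_K columns at a time."""
    acc = [[0.0 for _ in cols] for _ in rows]
    for k in range((K + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K):   # tl.cdiv(K, BLOCK_SIZE_K) iterations
        for offs_k in range(BLOCK_SIZE_K):
            in_bounds = offs_k < K - k * BLOCK_SIZE_K          # same mask as in the kernel
            kk = k * BLOCK_SIZE_K + offs_k
            for i, r in enumerate(rows):
                a = A[r][kk] if in_bounds else 0.0             # other=0.0
                for j, c in enumerate(cols):
                    b = B[kk][c] if in_bounds else 0.0         # other=0.0
                    acc[i][j] += a * b
    return acc


M, K, N = 3, 5, 2          # K = 5 is not a multiple of BLOCK_SIZE_K = 2
A = [[float(r + k) for k in range(K)] for r in range(M)]
B = [[float(k - c) for c in range(N)] for k in range(K)]

acc = matmul_block_masked(A, B, rows=list(range(M)), cols=list(range(N)), K=K, BLOCK_SIZE_K=2)
naive = [[sum(A[r][k] * B[k][c] for k in range(K)) for c in range(N)] for r in range(M)]
print(acc == naive)
```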


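```python
# Convert the fp32 accumulator to the output dtype (fp16, as in the official tutorial)
c = accumulator.to(tl.float16)
# Store offsets do not wrap, so the mask drops rows/columns that fall past M or N
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
```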
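Putting the pieces together, the whole kernel can be emulated on the CPU as a loop over pid. This pure-Python sketch is a simplified model, not the Triton kernel itself: it keeps the grouped ordering, the wrapping load offsets, the K mask and the store mask, but uses Python floats instead of fp16/fp32 and skips masked K elements rather than loading zeros (equivalent for a sum). `emulate_matmul_kernel` is a hypothetical name, and the sizes are chosen so that none of M, N, K is a multiple of its block size.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def emulate_matmul_kernel(A, B, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M):
    """CPU model of the kernel: one outer-loop iteration per program (pid)."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[None] * N for _ in range(M)]            # None marks "never written"
    num_pid_m, num_pid_n = cdiv(M, BLOCK_SIZE_M), cdiv(N, BLOCK_SIZE_N)
    for pid in range(num_pid_m * num_pid_n):
        # grouped ordering, identical to the kernel
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        # load offsets wrap with % M / % N, so every load stays in bounds
        offs_am = [(pid_m * BLOCK_SIZE_M + i) % M for i in range(BLOCK_SIZE_M)]
        offs_bn = [(pid_n * BLOCK_SIZE_N + j) % N for j in range(BLOCK_SIZE_N)]
        acc = [[0.0] * BLOCK_SIZE_N for _ in range(BLOCK_SIZE_M)]
        for k in range(cdiv(K, BLOCK_SIZE_K)):
            for offs_k in range(BLOCK_SIZE_K):
                if offs_k >= K - k * BLOCK_SIZE_K:   # masked: other=0.0 would add nothing
                    continue
                kk = k * BLOCK_SIZE_K + offs_k
                for i, r in enumerate(offs_am):
                    for j, c in enumerate(offs_bn):
                        acc[i][j] += A[r][kk] * B[kk][c]
        # store offsets do NOT wrap; the mask drops rows/columns past M or N
        for i in range(BLOCK_SIZE_M):
            for j in range(BLOCK_SIZE_N):
                cm, cn = pid_m * BLOCK_SIZE_M + i, pid_n * BLOCK_SIZE_N + j
                if cm < M and cn < N:
                    C[cm][cn] = acc[i][j]
    return C


M, K, N = 7, 5, 5
A = [[float(r - k) for k in range(K)] for r in range(M)]
B = [[float(k * c + 1) for c in range(N)] for k in range(K)]
C = emulate_matmul_kernel(A, B, BLOCK_SIZE_M=3, BLOCK_SIZE_N=2, BLOCK_SIZE_K=2, GROUP_SIZE_M=2)
expected = [[sum(A[r][k] * B[k][c] for k in range(K)) for c in range(N)] for r in range(M)]
print(C == expected)
```

If every entry matches the naive product, the wrapped loads and the two masks are doing their job: each element of C is written exactly once, by the program that owns it, and the garbage computed for wrapped rows and columns never reaches memory.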

References:

https://isamu-website.medium.com/understanding-the-triton-tutorials-part-1-6191b59ba4c
