0x0. 前言
继续补 在GPU上加速RWKV6模型的Linear Attention计算 没有写完的内容,对flash-linear-attention库(https://github.com/sustcsonglin/flash-linear-attention)中的fused_recurrent_rwkv6和chunk_rwkv6的前向实现进行解析,也是对Triton写cuda kernel进行继续学习。这里先解读一下fused_recurrent_rwkv6的实现,chunk_rwkv6的实现后续随缘说。
0x1. fused_recurrent_rwkv6 naive python实现
还是从naive的python实现看起,https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/rwkv6/recurrent_naive.py 。fused_recurrent_rwkv6计算算法对应下面的基础python流程:
def naive_recurrent_rwkv6(
q,
k,
v,
w,
u,
initial_state=None,
output_final_state=False
):
# 记录输入张量 q 的原始数据类型。
orig_dtype = q.dtype
# 将输入张量转换为 32 位浮点数类型。
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
# 获取query张量的形状信息。
batch_size, n_heads, seq_len, d_head_k = q.shape
# 获取值张量的形状信息。
_, _, _, d_head_v = v.shape
# 初始化注意力张量为全零张量,形状为 (B, H, D, D),在 GPU 上进行计算。
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
# 初始化输出张量为全零张量,形状同值张量 v
o = torch.zeros_like(v)
# 如果提供了初始状态 initial_state,则将注意力张量 h 更新为初始状态:
if initial_state is not None:
h += initial_state
# 对序列长度进行迭代,每次迭代处理一个位置的输入:
for i in range(seq_len):
q_i = q[:, :, i, :] # 获取当前位置的query张量。shape为[B, H, D]
k_i = k[:, :, i] # 获取当前位置的key张量。shape为[B, H, D]
v_i = v[:, :, i, :] # 获取当前位置的value张量。shape为[B, H, D]
# 获取当前位置的权重张量,并使用指数函数进行处理。shape为[B, H, D]
w_i = w[:, :, i].exp()
# 计算当前位置的键值乘积,elementwise操作。
# shape变化为[B, H, D, 1] * [B, H, D, 1] -> [B, H, D, 1]
kv_i = k_i[..., None] * v_i[..., None, :]
# 计算当前位置的注意力加权输出,都是elementwise操作。
# h的shape为[B, H, D, D]
# u[None, ..., None]的shape为[1, H, D, 1]
# q_i[..., None]的shape为[B, H, D, 1]
# h + u[None, ..., None] * kv_i 的shape为:
# [B, H, D, D] + [1, H, D, 1] * [B, H, D, 1] ->
# [B, H, D, D] + [B, H, D, 1] ->
# [B, H, D, D]
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
# 将当前位置的输出加入到输出张量中。
# o[:, :, i]的shape为[B, H, D],o_i.sum(-2)的shape为[B, H, D]
o[:, :, i] = o_i.sum(-2)
# 更新注意力张量 h
# h的shape为[B, H, D, D]
# w_i[..., None]的shape为[B, H, D, 1]
# kv_i的shape为[B, H, D, 1]
# h * w_i[..., None] 的shape为[B, H, D, D]也是element-wise操作
h = h * w_i[..., None] + kv_i
return o.to(orig_dtype)
q, k, v, w, u等定义如下:
B = 4 # 批量大小(batch size)为 4。
H = 4 # 头数(number of heads)为 4。
L = 1024 # 序列长度(sequence length)为 1024。
D = 100 # 每个头的维度(dimension)为 100。
dtype = torch.float32 # 定义了张量的数据类型为 32 位浮点数。
# q, k, v 分别是查询(query)、键(key)、值(value)的张量,形状为 (B, H, L, D),
# 使用随机初始化,并且在 GPU 上进行计算。
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, D).cuda().to(dtype)