Grouped-query attention (GQA) is an interpolation between multi-query and multi-head attention that achieves quality close to multi-head attention at a speed comparable to multi-query attention.
The figure above is from https://arxiv.org/pdf/2305.13245v3.pdf: 8 heads, split into 4 groups.
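Before the full module, here is a two-line sketch of the core trick (the toy tensor below is mine, purely illustrative): each of the 4 kv heads gets copied kv_per_group times with repeat_interleave, so that all 8 query heads end up with a matching kv head.

import torch

num_groups, kv_per_group = 4, 2            # 8 query heads sharing 4 kv heads
kv = torch.arange(num_groups)              # stand-ins for the 4 kv heads: tensor([0, 1, 2, 3])
print(kv.repeat_interleave(kv_per_group))  # tensor([0, 0, 1, 1, 2, 2, 3, 3])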
Now the full code:
import torch.nn as nn
import torch
from torch import Tensor
import math
class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        # Split the num_heads heads into num_groups groups; heads within a group
        # share the same k/v, which we expand with repeat_interleave at the end.
        # In the figure above: 8 heads split into 4 groups, same k/v within each group.
        # embed_dim = num_heads * head_dim
        # num_heads = num_groups * kv_per_group
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.kv_per_group = num_heads // num_groups
        head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, head_dim * num_groups)
        self.v_proj = nn.Linear(embed_dim, head_dim * num_groups)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)
    def scaled_dot_product_attention(self, q: Tensor, k: Tensor, v: Tensor):
        # q, k, v: (bs * num_heads, seq_len, head_dim)
        bs, q_len, head_dim = q.shape
        q = q / math.sqrt(head_dim)
        # (bs, q_len, head_dim) x (bs, head_dim, k_len) -> (bs, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        # (bs, q_len, k_len) x (bs, k_len, head_dim) -> (bs, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output, attn
    def forward(self, query: Tensor, key: Tensor, value: Tensor):
        # query, key, value are assumed to have the same shape: (bs, seq_len, embed_dim)
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs * self.num_heads, q_len, head_dim)
        # Project k/v down to num_groups heads first, then repeat each group
        # kv_per_group times along the head dimension so every query head has a
        # matching k/v head. (Repeating along the batch dimension before the
        # reshape, as in the original post, scrambles positions into heads.)
        k = self.k_proj(key).reshape(bs, q_len, self.num_groups, head_dim).transpose(1, 2)
        k = k.repeat_interleave(self.kv_per_group, dim=1).reshape(bs * self.num_heads, q_len, head_dim)
        v = self.v_proj(value).reshape(bs, q_len, self.num_groups, head_dim).transpose(1, 2)
        v = v.repeat_interleave(self.kv_per_group, dim=1).reshape(bs * self.num_heads, q_len, head_dim)
        self_output, attn = self.scaled_dot_product_attention(q, k, v)
        # self_output: (bs * num_heads, q_len, head_dim)
        # attn: (bs * num_heads, q_len, k_len)
        output = self.fc(self_output.reshape(bs, self.num_heads, q_len, head_dim).transpose(1, 2).reshape(bs, q_len, self.num_heads * head_dim))
        # The Hugging Face version moves this fc into BertSelfOutput.
        return self.ln(output + query), attn
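Aside: since PyTorch 2.0 there is a fused torch.nn.functional.scaled_dot_product_attention that can replace the hand-rolled bmm helper above (it does not return the attention weights, though). A minimal sketch with dummy tensors of mine, in the same (bs * num_heads, seq_len, head_dim) layout:

import torch
import torch.nn.functional as F

q = torch.randn(24, 2, 32)  # (bs * num_heads, q_len, head_dim)
k = torch.randn(24, 2, 32)
v = torch.randn(24, 2, 32)
out = F.scaled_dot_product_attention(q, k, v)  # same math as the helper above, fused
print(out.shape)  # torch.Size([24, 2, 32])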
embed_dim, num_heads, num_groups = 256, 8, 4
q_len, bs = 2, 3
# batch_first=True so nn.MultiheadAttention also takes (bs, q_len, embed_dim)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
query = torch.ones(bs, q_len, embed_dim)
key = torch.ones(bs, q_len, embed_dim)
value = torch.ones(bs, q_len, embed_dim)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print('attn_output={}'.format(attn_output.shape))
print('attn_output_weights={}'.format(attn_output_weights.shape))
print('--------------')
my_multihead_attn = MyGQA(embed_dim, num_heads, num_groups)
my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
print('my_attn_output={}'.format(my_attn_output.shape))
print('my_attn_output_weights={}'.format(my_attn_output_weights.shape))
'''
Output:
attn_output=torch.Size([3, 2, 256])
attn_output_weights=torch.Size([3, 2, 2])
--------------
my_attn_output=torch.Size([3, 2, 256])
my_attn_output_weights=torch.Size([24, 2, 2])
'''
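One more back-of-the-envelope check (mine, not from the paper): with embed_dim=256, num_heads=8, num_groups=4, k_proj and v_proj each map 256 -> 128 instead of 256 -> 256, so GQA halves the kv projection parameters and, more importantly, halves the kv cache at inference time.

embed_dim, num_heads, num_groups = 256, 8, 4
head_dim = embed_dim // num_heads

# weights + biases of k_proj and v_proj in standard multi-head attention
mha_kv_params = 2 * (embed_dim * embed_dim + embed_dim)
# weights + biases of k_proj and v_proj in MyGQA above
gqa_kv_params = 2 * (embed_dim * head_dim * num_groups + head_dim * num_groups)
print(mha_kv_params, gqa_kv_params)  # 131584 65792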