GQA and MQA
Grouped-query attention is an interpolation between multi-query and multi-head attention: it achieves quality close to multi-head attention at a speed comparable to multi-query attention.
The figure above is from https://arxiv.org/pdf/2305.13245v3.pdf: 8 heads are split into 4 groups.
Code up front; in short:
embed_dim = num_heads * head_dim = num_groups * kv_per_group * head_dim
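To make the bookkeeping concrete, here is a minimal sketch with illustrative numbers (8 heads and 4 groups match the figure; embed_dim=512 is my own assumption). It also shows how MQA and MHA fall out as the two extremes of num_groups:

# Illustrative numbers: 8 heads / 4 groups as in the figure; embed_dim=512 is an assumption.
embed_dim, num_heads, num_groups = 512, 8, 4
head_dim = embed_dim // num_heads        # 64
kv_per_group = num_heads // num_groups   # 2 query heads share each K/V head
assert embed_dim == num_heads * head_dim == num_groups * kv_per_group * head_dim
# Special cases of the same formula:
#   num_groups == 1          -> MQA: all query heads share a single K/V head
#   num_groups == num_heads  -> MHA: every query head has its own K/V head
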
import torch.nn as nn
import torch
from torch import Tensor
import math
import torch.nn.init as init
class CustomLinearModel(nn.Module):
    def __init__(self, in_features, out_features):
        super(CustomLinearModel, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self._initialize_weights()

    def forward(self, x):
        return self.linear(x)

    # Custom weight initialization
    def _initialize_weights(self):
        with torch.no_grad():
            self.linear.weight.fill_(0.01)  # initialize all weights to 0.01
            if self.linear.bias is not None:
                self.linear.bias.fill_(0.1)  # initialize all biases to 0.1
class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        # Split num_heads into num_groups groups; the heads in a group share the same K/V,
        # which is then expanded back to num_heads with repeat_interleave.
        # In the figure above, 8 heads are split into 4 groups, each group sharing one K/V.
        # embed_dim = num_heads * head_dim
        # num_heads = num_groups * kv_per_group
        super(MyGQA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.kv_per_group = num_heads // num_groups
        head_dim = embed_dim // num_heads
        # self.q_proj = nn.Linear(embed_dim, embed_dim)
        # self.k_proj = nn.Linear(embed_dim, head_dim * num_groups)
        # self.w_proj = nn.Linear(embed_dim, head_dim * num_groups)
        # self.fc = nn.Linear(embed_dim, embed_dim)
        self.q_proj = CustomLinearModel(embed_dim, embed_dim)              # queries: one per head
        self.k_proj = CustomLinearModel(embed_dim, head_dim * num_groups)  # keys: one per group
        self.w_proj = CustomLinearModel(embed_dim, head_dim * num_groups)  # values: one per group
        self.fc = CustomLinearModel(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)
    def scaled_dot_product_attention(self, q: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor):
        bs, q_len, head_dim = q.shape  # bs here is really batch_size * num_heads
        q = q / math.sqrt(head_dim)
        # (bs, q_len, head_dim) x (bs, k_len, head_dim) -> (bs, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        if attn_mask is not None:
            # positions where attn_mask is True are masked out
            attn = attn.masked_fill(attn_mask == True, float('-inf'))
        attn = attn.softmax(dim=-1)
        # (bs, q_len, k_len) x (bs, v_len, head_dim) -> (bs, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output, attn
    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor = None):
        # query, key and value are assumed to have the same shape (bs, q_len, embed_dim)
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        # q: (bs, q_len, embed_dim) -> (bs*num_heads, q_len, head_dim)
        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs * self.num_heads, q_len, head_dim)
        # k/v: project to num_groups K/V heads, then expand to num_heads with repeat_interleave
        # so that query heads in the same group see identical K/V:
        # (bs, q_len, num_groups*head_dim) -> (bs*num_heads, q_len, head_dim)
        k = self.k_proj(key).repeat_interleave(self.kv_per_group, dim=0).reshape(bs, self.kv_per_group, q_len, self.num_groups, head_dim).transpose(2, 3).reshape(bs * self.num_heads, q_len, head_dim)
        v = self.w_proj(value).repeat_interleave(self.kv_per_group, dim=0).reshape(bs, self.kv_per_group, q_len, self.num_groups, head_dim).transpose(2, 3).reshape(bs * self.num_heads, q_len, head_dim)
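The repeat_interleave/reshape chain above is the heart of the implementation, so here is a small self-contained sketch (the helper name expand_kv and the dummy shapes are mine, not part of the class) that applies the same sequence to a dummy grouped K tensor and checks which expanded head slices share the same K/V; with this layout, query head h ends up paired with group h % num_groups:

import torch

def expand_kv(x, kv_per_group, num_groups, head_dim):
    # x: (bs, seq_len, num_groups * head_dim), i.e. what k_proj / w_proj produce
    bs, seq_len, _ = x.shape
    num_heads = num_groups * kv_per_group
    return (x.repeat_interleave(kv_per_group, dim=0)
             .reshape(bs, kv_per_group, seq_len, num_groups, head_dim)
             .transpose(2, 3)
             .reshape(bs * num_heads, seq_len, head_dim))

bs, seq_len, num_groups, kv_per_group, head_dim = 1, 3, 4, 2, 5
num_heads = num_groups * kv_per_group
x = torch.randn(bs, seq_len, num_groups * head_dim)
k = expand_kv(x, kv_per_group, num_groups, head_dim)  # (bs*num_heads, seq_len, head_dim)

# Expanded head h carries the K/V of group h % num_groups, so heads h and h + num_groups
# share identical K/V.
for h in range(num_heads):
    g = h % num_groups
    assert torch.equal(k[h], x[0, :, g * head_dim:(g + 1) * head_dim])

Note that this assigns heads to groups in an interleaved order rather than the more common layout where kv_per_group consecutive heads share a group; since the query projection is learned, the two layouts differ only by a relabelling of the query heads.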