Theory reference: Understanding Attention: from its origins to MHA, MQA and GQA | Linsight
Method for upgrading existing models: https://blog.nghuyong.top/2023/09/10/NLP/llm-attention/
Explanation of the figure in the paper: the difference between the variants lies in the total feature size of the K/V projections; each square in the figure is head_dim dimensions wide.
PyTorch code implementation:
import math
import torch

class BaseAttention(torch.nn.Module):
    def __init__(self):
        super(BaseAttention, self).__init__()
        self.softmax = torch.nn.Softmax(dim=-1)

    def attention(self, q, k, v, mask=None, dropout=None):
        # Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)  # block masked positions before softmax
        attn = self.softmax(attn)
        if dropout is not None:
            attn = dropout(attn)
        return torch.matmul(attn, v)
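To complement the figure explanation above, here is a minimal sketch of how the total K/V feature size differs between MHA, GQA and MQA while the Q projection stays the same. The names hidden_dim, num_heads, num_kv_heads, head_dim and the concrete sizes are illustrative assumptions, not values from the paper.

import torch

# Illustrative sizes (assumptions, not taken from the paper)
hidden_dim, num_heads, head_dim, seq_len = 512, 8, 64, 10

x = torch.randn(1, seq_len, hidden_dim)
for name, num_kv_heads in [("MHA", num_heads), ("GQA", 2), ("MQA", 1)]:
    q_proj = torch.nn.Linear(hidden_dim, num_heads * head_dim)          # Q keeps num_heads heads in all variants
    kv_proj = torch.nn.Linear(hidden_dim, 2 * num_kv_heads * head_dim)  # K and V each get num_kv_heads heads
    q = q_proj(x)
    k, v = kv_proj(x).chunk(2, dim=-1)
    print(f"{name}: q features={q.shape[-1]}, k/v features={k.shape[-1]} each")

At attention time, each K/V head is shared by num_heads // num_kv_heads query heads, which is what the grouping of squares in the figure depicts.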