Grouped-query Attention (GQA), Multi-query Attention (MQA), Multi-Head Latent Attention (MLA)

This post introduces Grouped-query attention, an attention variant that interpolates between multi-query and multi-head attention, achieving quality close to multi-head attention at near multi-query speed. It walks through a PyTorch implementation and compares it against the standard multi-head attention module.


GQA and MQA

Grouped-query attention is an interpolation between multi-query and multi-head attention: it achieves quality close to multi-head attention at a speed comparable to multi-query attention. With num_groups = 1 it reduces to MQA (all query heads share a single K/V head), and with num_groups = num_heads it recovers standard multi-head attention.

(Figure from https://arxiv.org/pdf/2305.13245v3.pdf: 8 query heads split into 4 groups, each group sharing one K/V head.)

Code first; in short:

embed_dim = num_heads * head_dim = num_groups * kv_per_group * head_dim
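For example, for the 8-head / 4-group configuration in the figure, and assuming embed_dim = 512 purely for illustration, the bookkeeping works out as follows:

embed_dim, num_heads, num_groups = 512, 8, 4   # embed_dim is an assumed example value
head_dim = embed_dim // num_heads              # 64
kv_per_group = num_heads // num_groups         # 2 query heads share each K/V head
# Q is projected to num_heads * head_dim = 512 features,
# while K and V are each projected to only num_groups * head_dim = 256 features.
assert embed_dim == num_heads * head_dim == num_groups * kv_per_group * head_dim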
import torch.nn as nn
import torch
from torch import Tensor
import math
import torch.nn.init as init

class CustomLinearModel(nn.Module):
    def __init__(self, in_features, out_features):
        super(CustomLinearModel, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self._initialize_weights()

    def forward(self, x):
        return self.linear(x)

    # Custom weight initialization: fill weights and biases with fixed constants
    def _initialize_weights(self):
        with torch.no_grad():
            self.linear.weight.fill_(0.01)  # initialize all weights to 0.01
            if self.linear.bias is not None:
                self.linear.bias.fill_(0.1)  # initialize all biases to 0.1

class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        # Split the num_heads query heads into num_groups groups; query heads in the
        # same group share one K/V head, broadcast with repeat_interleave in forward().
        # In the figure above, 8 heads are split into 4 groups, each sharing one K/V head.
        # embed_dim = num_heads * head_dim
        # num_heads = num_groups * kv_per_group
        super(MyGQA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.kv_per_group = num_heads // num_groups
        head_dim = embed_dim // num_heads

        # The same projections with plain nn.Linear:
        # self.q_proj = nn.Linear(embed_dim, embed_dim)
        # self.k_proj = nn.Linear(embed_dim, head_dim * num_groups)
        # self.v_proj = nn.Linear(embed_dim, head_dim * num_groups)
        # self.fc = nn.Linear(embed_dim, embed_dim)

        self.q_proj = CustomLinearModel(embed_dim, embed_dim)
        self.k_proj = CustomLinearModel(embed_dim, head_dim * num_groups)
        self.v_proj = CustomLinearModel(embed_dim, head_dim * num_groups)
        self.fc = CustomLinearModel(embed_dim, embed_dim)

        self.ln = nn.LayerNorm(embed_dim)

    def scaled_dot_product_attention(self, q:Tensor, k:Tensor, v:Tensor, attn_mask:Tensor):
        bs, q_len, head_dim = q.shape   # here bs is actually bs * num_heads
        q = q / math.sqrt(head_dim)
        # (bs*num_heads, q_len, head_dim) x (bs*num_heads, head_dim, k_len) -> (bs*num_heads, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        if attn_mask is not None:
            # positions where attn_mask is True are masked out
            attn = attn.masked_fill(attn_mask == True, float('-inf'))
        attn = attn.softmax(dim=-1)
        # (bs*num_heads, q_len, k_len) x (bs*num_heads, k_len, head_dim) -> (bs*num_heads, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output, attn

    def forward(self, query:Tensor, key:Tensor, value:Tensor, attn_mask:Tensor=None):
        # query, key and value are assumed to have the same shape: (bs, seq_len, embed_dim)
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        # q: (bs, q_len, embed_dim) -> (bs*num_heads, q_len, head_dim)
        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        # k/v: project to num_groups K/V heads, then replicate them so that every one of the
        # num_heads query heads gets a matching (q_len, head_dim) K/V slice
        k = self.k_proj(key).repeat_interleave(self.kv_per_group, dim=0).reshape(bs, self.kv_per_group, q_len, self.num_groups, head_dim).transpose(2, 3).reshape(bs*self.num_heads, q_len, head_dim)
        v = self.v_proj(value).repeat_interleave(self.kv_per_group, dim=0).reshape(bs, self.kv_per_group, q_len, self.num_groups, head_dim).transpose(2, 3).reshape(bs*self.num_heads, q_len, head_dim)
        output, attn = self.scaled_dot_product_attention(q, k, v, attn_mask)
        # merge heads back: (bs*num_heads, q_len, head_dim) -> (bs, q_len, embed_dim)
        output = output.reshape(bs, self.num_heads, q_len, head_dim).transpose(1, 2).reshape(bs, q_len, embed_dim)
        # the original post is cut off here; finishing with the fc projection and the
        # LayerNorm defined in __init__ is an assumption
        return self.ln(self.fc(output)), attn
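A minimal usage sketch (batch size, sequence length and embed_dim below are arbitrary example values): it runs MyGQA on random input, optionally with a causal mask, and checks that the output shape matches PyTorch's standard nn.MultiheadAttention.

torch.manual_seed(0)
embed_dim, num_heads, num_groups = 512, 8, 4   # example sizes, chosen for illustration
bs, seq_len = 2, 16

x = torch.randn(bs, seq_len, embed_dim)
gqa = MyGQA(embed_dim, num_heads, num_groups)
out, attn = gqa(x, x, x)
print(out.shape)    # torch.Size([2, 16, 512])
print(attn.shape)   # torch.Size([16, 16, 16]) -> (bs*num_heads, q_len, k_len)

# optional causal mask: True positions are masked out
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
out_causal, _ = gqa(x, x, x, attn_mask=causal)

# PyTorch's standard multi-head attention produces the same output shape
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
ref_out, _ = mha(x, x, x)
print(ref_out.shape)  # torch.Size([2, 16, 512])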