Grouped-query Attention (GQA) and Multi-query Attention (MQA)

This post introduces Grouped-query attention, an approach that interpolates between multi-query and multi-head attention, reaching quality close to multi-head attention at a speed close to multi-query attention. It walks through a PyTorch implementation of the method and compares it against the standard multi-head attention module.

Grouped-query attention is an interpolation of multi-query and multi-head attention that achieves quality close to multi-head attention at speed comparable to multi-query attention.
(Figure from https://arxiv.org/pdf/2305.13245v3.pdf: 8 query heads split into 4 groups.)
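To make "interpolation" concrete: the only thing that changes between multi-head, grouped-query, and multi-query attention is how many key/value heads the k/v projections produce. A minimal sketch, assuming the sizes in the figure (8 query heads, 4 groups) plus an assumed head_dim of 32 (so embed_dim = 256); the numbers are illustrative, not from the post:

import torch.nn as nn

embed_dim, head_dim = 256, 32
k_proj_mha = nn.Linear(embed_dim, 8 * head_dim)  # MHA: one k/v head per query head -> 256
k_proj_gqa = nn.Linear(embed_dim, 4 * head_dim)  # GQA: one k/v head per group      -> 128
k_proj_mqa = nn.Linear(embed_dim, 1 * head_dim)  # MQA: a single shared k/v head    -> 32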

Here is the code:

import torch.nn as nn
import torch
from torch import Tensor
import math

class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        # Split the num_heads query heads into num_groups groups; heads within a group
        # share the same k/v, which is then repeat_interleave'd back to one k/v per head.
        # In the figure above, 8 heads are split into 4 groups, so every 2 heads share one k/v.
        # embed_dim = num_heads * head_dim
        # num_heads = num_groups * kv_per_group
        super(MyGQA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.kv_per_group = num_heads // num_groups
        head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, head_dim*num_groups)
        self.v_proj = nn.Linear(embed_dim, head_dim*num_groups)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)

    def scaled_dot_product_attention(self, q:Tensor, k:Tensor, v:Tensor):
        # q, k, v: (bs*num_heads, seq_len, head_dim)
        bs, q_len, head_dim = q.shape
        q = q / math.sqrt(head_dim)
        # (bs*num_heads, q_len, head_dim) x (bs*num_heads, head_dim, k_len) -> (bs*num_heads, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        # (bs*num_heads, q_len, k_len) x (bs*num_heads, k_len, head_dim) -> (bs*num_heads, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output, attn

    def forward(self, query:Tensor, key:Tensor, value:Tensor):
        # query, key, value are assumed to have the same shape: (bs, q_len, embed_dim)
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        # queries: one head_dim slice per head -> (bs*num_heads, q_len, head_dim)
        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        # keys/values: one head_dim slice per group, repeated kv_per_group times along the head
        # dimension so that all heads within a group share the same k/v
        k = self.k_proj(key).reshape(bs, q_len, -1, head_dim).repeat_interleave(self.kv_per_group, dim=2).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        v = self.v_proj(value).reshape(bs, q_len, -1, head_dim).repeat_interleave(self.kv_per_group, dim=2).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        self_output, attn = self.scaled_dot_product_attention(q, k, v)
        # self_output: (bs*num_heads, q_len, head_dim)
        # attn:        (bs*num_heads, q_len, k_len)
        output = self.fc(self_output.reshape(bs, self.num_heads, q_len, head_dim).transpose(1, 2).reshape(bs, q_len, self.num_heads*head_dim))
        # the Hugging Face implementation moves this fc into BertSelfOutput
        return self.ln(output + query), attn
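
# A quick illustration (my addition, not in the original post): the two extremes of
# the interpolation fall out of the same class.
#   num_groups == 1         -> every head shares a single k/v head (multi-query attention)
#   num_groups == num_heads -> one k/v head per query head (standard multi-head layout)
mqa_layer = MyGQA(embed_dim=256, num_heads=8, num_groups=1)
mha_like_layer = MyGQA(embed_dim=256, num_heads=8, num_groups=8)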

embed_dim, num_heads, num_groups = 256, 8, 4
q_len, bs = 2, 3

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)  # batch_first so inputs are (bs, q_len, embed_dim), matching MyGQA
query = torch.ones(bs, q_len, embed_dim)
key = torch.ones(bs, q_len, embed_dim)
value = torch.ones(bs, q_len, embed_dim)

attn_output, attn_output_weights = multihead_attn(query, key, value)
print('attn_output={}'.format(attn_output.shape))
print('attn_output_weights={}'.format(attn_output_weights.shape))
print('--------------')

my_multihead_attn = MyGQA(embed_dim, num_heads, num_groups)
my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
print('my_attn_output={}'.format(my_attn_output.shape))
print('my_attn_output_weights={}'.format(my_attn_output_weights.shape))

'''
Output:
attn_output=torch.Size([3, 2, 256])
attn_output_weights=torch.Size([3, 2, 2])
--------------
my_attn_output=torch.Size([3, 2, 256])
my_attn_output_weights=torch.Size([24, 2, 2])
'''
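
One difference worth noting when comparing the two weight tensors: nn.MultiheadAttention averages its attention weights over heads by default, while MyGQA returns them per head, flattened into the batch dimension. A small follow-up sketch (my addition, assuming the shapes above) to put them in the same layout:

# Reshape (bs*num_heads, q_len, k_len) -> (bs, num_heads, q_len, k_len), then average
# over heads to match nn.MultiheadAttention's default head-averaged weights.
per_head_weights = my_attn_output_weights.reshape(bs, num_heads, q_len, q_len)
averaged_weights = per_head_weights.mean(dim=1)
print('averaged my_attn_output_weights={}'.format(averaged_weights.shape))  # torch.Size([3, 2, 2])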