Pytorch学习笔记:nn.MultiheadAttention——多头注意力机制

Pytorch学习笔记:nn.MultiheadAttention——多头注意力机制

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

功能:创建一个多头注意力模块,参考论文《transformer》,参考论文及源码笔记:https://blog.csdn.net/qq_50001789/article/details/132181971

多头注意力公式为:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O MultiHead(Q,K,V)=Concat(head_1,\dots,head_h)W^O MultiHead(Q,K,V)=Concat(head1,,headh)WO
其中 h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i=Attention(QW^Q_i,KW_i^K,VW^V_i) headi=Attention(QWiQ,KWiK,VWiV),流程图如下:

在这里插入图片描述

参数:

  • embed_dim :输入数据的维度,也就是向量的长度;

  • num_heads:表示并行注意力的数量,也就是“头”的数量;

  • dropout:表示注意力权重的丢弃概率,相当于生成注意力之后,再将注意力传入一层Dropout层,默认为0;

  • bias:在做线性变换Linear时,是否添加偏置,默认True

  • add_bias_kvkv做线性变换时是否加偏置,若键值维度与嵌入维度相同,则可以将add_bias_kv设为False,默认False

  • add_zero_attn:将一个全零注意力向量添加到最终的输出中(只影响形状,不改变数值),强制使输出张量的形状与输入张量相同

  • kdim:keys的特征数据维度,即向量长度,默认与embed_dim相等

  • vdim:values的特征数据维度,即向量长度,默认与embed_dim相等

  • batch_first:如果设为True,则输入、输出张量表示为(batch, seq, feature),否则张量表示为(seq, batch, feature),默认False

注意:

  • embed_dim 会被划分成num_heads份,对应的数据也会被划分,传入不同的“head”里,每个“head”的维度是embed_dim // num_heads

前向传播

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)

参数:

  • query, key, value:表示传入的qkv数据,形式因batch_first变量而异,默认(seq, batch, feature),即(序列,batch,特征);

  • key_padding_mask:用于指定哪些位置是填充位置,以便在计算注意力权重时将其忽略。对于batch数据,输入尺寸应为 ( N , S ) (N,S) (N,S),其中 S S S为序列长度,对于非batch数据,输入尺寸应为 S S S,里面的数值可以是布尔、也可以是浮点数。常用布尔数据,True表示该位置为填充,计算注意力的时候需要忽略该位置,如果传入浮点数,则会将该数与key相加,常加负数,用于抑制该位置(False与负无穷效果一样);

  • need_weights:如果指定为True,则网络会额外输出注意力权重;

  • attn_mask:尺寸为 ( L , S ) (L,S) (L,S) ( N , n u m h e a d s , L , S ) (N,num_heads,L,S) (N,numheads,L,S),其中 L L L表示目标序列长度,数值表示位置, S S S表示源序列长度,数值表示位置,如果 attn_mask[b, :, i, j] 为 True,则表示第 b 个样本、第 i 个目标位置和第 j 个源位置之间需要进行注意力计算;

  • average_attn_weights:表示是否要对多头注意力中的权重沿“头”方向做平均,将多组注意力矩阵生成一组矩阵,设为True时,表示需要做平均,即生成一个注意力矩阵,默认True,即生成每个头的注意力矩阵。只有当need_weights设置为True时,该参数才有意义;

  • is_causal:如果 is_causal 为 True,表明目标序列中的每个位置只能依赖于它之前的位置,这个操作能够实现因果性,默认False。这个参数只作为一个提示,最终是否是因果的,还是要看参数attn_mask。

注:

  • 权重由计算k、q的相似度得到,得到的权重再与v相乘,做加权求和;

  • 计算过程:

  先让kqv做线性映射,之后沿特征向量的方向拆分成不同的“头”,之后利用拆分的向量做运算→q和k做矩阵乘法,得到注意力权重→注意力权重除以缩放因子 d k \sqrt{d_k} dk d k d_k dk表示每个头的维度,再做Softmax运算→经过一次Dropout运算(可选)→所得的权重与v做矩阵乘法→合并所有“头”,最后经过一次线性映射;

  • 多头是拆特征,不是拆序列;

多头注意力K、Q、V解释:

  • 目前有多组键值匹配对k、v,每个k对应一个v,计算q所对应的值。思路:计算q与每个k的相似度,得到v的权重,之后对v做加权求和,得到q对应的数值。因此在解码过程中,第二个多头注意力的输入中,k、v传入编码特征(是已知的特征匹配对),q传入解码特征(可迭代传入),求解码对应的特征(根据编码特征之间的相似度求解码的注意力加权特征)。

注:kqv的关系用一句话来说就是根据kv的键值匹配关系,预测q对应的数值,根据kq的相似度对v做加权求和

视频参考:https://www.bilibili.com/video/BV1dt4y1J7ov/?spm_id_from=333.788.recommend_more_video.2&vd_source=b1b1710a3f74753e8bfc47c5c2e4d49e

实现方法

代码来源:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

官方文档

nn.MultiheadAttention:https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=attention#torch.nn.MultiheadAttention

  • 12
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值