论文及源码笔记:Transformer——Attention Is All You Need

论文及源码笔记:Transformer——Attention Is All You Need

综述

论文题目:《Attention Is All You Need》

时间会议:Advances in Neural Information Processing Systems, 2017 (NIPS, 2017)

论文链接:http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf

介绍

  transformer结构的优点:

  • 长程依赖性处理能力强:自注意力机制可以实现对整张图片进行全局信息的建模
  • 并行化能力强:可以并行计算输入序列中的所有位置;

  网络结构图如下所示:

在这里插入图片描述

其中,多头注意力网络结构为:

在这里插入图片描述

注意力公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V\\ MultiHead(Q,K,V)=Concat(head_1,\dots,head_h)W^O\\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) Attention(Q,K,V)=softmax(dk QKT)VMultiHead(Q,K,V)=Concat(head1,,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)
其中:

  • Q代表query,后续会去和每一个k进行匹配;

  • K代表key,后续会被每个q匹配;

  • V代表输入数据,又称为编码数据Embedding;

  • 后续Q和K匹配的过程可以理解成计算两者的相关性,相关性越大对应v的权重也就越大,主要就是给V生成一组权重;

多头注意力K、Q、V解释:

  目前有多组键值匹配对k、v,每个k对应一个v,计算q所对应的值。思路:计算q与每个k的相似度,得到v的权重,之后对v做加权求和,得到q对应的数值。因此在解码过程中,第二个多头注意力的输入中,k、v传入编码特征(是已知的特征匹配对),q传入解码特征(可迭代传入),求解码对应的特征(根据编码特征之间的相似度求解码的注意力加权特征)。

  注:kqv的关系用一句话来说就是根据kv的键值匹配关系,预测q对应的数值,根据kq的相似度对v做加权求和

视频参考:https://www.bilibili.com/video/BV1dt4y1J7ov/?spm_id_from=333.788.recommend_more_video.2&vd_source=b1b1710a3f74753e8bfc47c5c2e4d49e

多头注意力计算过程:

  先让kqv做线性映射,之后沿特征向量的方向拆分成不同的“头”,之后利用拆分的向量做运算→q和k做矩阵乘法,得到注意力权重→注意力权重除以缩放因子 d k \sqrt{d_k} dk d k d_k dk表示每个头的维度,再做Softmax运算→经过一次Dropout运算(可选)→所得的权重与v做矩阵乘法→合并所有“头”,最后经过一次线性映射;

代码实现

编码模块

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)

        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self,
                     src,
                     pos: Optional[Tensor] = None):
        # 特征先与位置编码相加
        v = q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, v)[0]
        src = src + src2
        src = self.norm1(src)
        src2 = self.linear2(self.activation(self.linear1(src)))
        src = src + src2
        src = self.norm2(src)
        return src

解码模块

class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, encoding, decoding, pos: Optional[Tensor] = None):
        q = k = v = self.with_pos_embed(decoding, pos)
        # 解码特征先做一次自注意力运算
        decoding2 = self.self_attn(q, k, v)[0]
        decoding = decoding + decoding2
        decoding = self.norm1(decoding)

        decoding2 = self.multihead_attn(query=decoding, key=encoding, value=encoding)[0]
        decoding = decoding + decoding2
        decoding = self.norm2(decoding)
        decoding2 = self.linear2(self.activation(self.linear1(decoding)))
        decoding = decoding + decoding2
        decoding = self.norm3(decoding)
        return decoding

注:以上仅是笔者个人见解,若有问题,欢迎指正。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值