2021 《Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks》 Pytorch实现

import torch
from torch import nn
from torch.nn import init


# External Attention
# 外部注意力
# 方法出处 2021 arxiv 《Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks》
class ExternalAttention(nn.Module):
    # 网络层的初始化
    def __init__(self, d_model, S=64):
        # 所有继承于nn.Module的模型都要写这句话
        super(ExternalAttention, self).__init__()
        # 外部记忆单元1
        self.mk = nn.Linear(d_model, S, bias=False)
        # 外部记忆单元2
        self.mv = nn.Linear(S, d_model, bias=False)
        # softmax层
        self.softmax = nn.Softmax(dim=1)
        # 网络层权重初始化
        self.init_weights()

    def init_weights(self):
        # 遍历当前模型所有的层
        for m in self.modules():
            # 如果是卷积层
            if isinstance(m, nn.Conv2d):
                # kaiming初始化
                init.kaiming_normal_(m.weight, mode='fan_out')
                # 偏置初始化为0
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            # 如果是正则化层
            elif isinstance(m, nn.BatchNorm2d):
                # 权重为1,偏置为0
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            # 如果是线性层
            elif isinstance(m, nn.Linear):
                # 正态分布初始化
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)  # bs,n,S
        # 沿着第一个维度n进行softmax
        # 对于每一个切片矩阵n*s
        # 的每一列进行softmax
        # 相当于捕获不同样本的相似性
        attn = self.softmax(attn)  # bs,n,S
        # 这步相当于正则化
        # toch.sum(attn,dim=2,keepdim=True)
        # 是对于attn沿着第二维度相加
        # 输出结果维度是[bs,n,1]
        # 通过广播机制去除attn
        # 相当于对于attn的每一行进行softmax
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # bs,n,S
        out = self.mv(attn)  # bs,n,d_model

        return out


if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ea = ExternalAttention(d_model=512, S=8)
    output = ea(input)
    print(output.shape)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_Old_Summer

感谢老板!!!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值