【Module Stitching】【2022 TPAMI】External Attention: an attention mechanism that works like a dictionary

paper: https://arxiv.org/pdf/2105.02358v2
code: https://paperswithcode.com/paper/beyond-self-attention-external-attention


Overview:

Abstract:

Attention mechanisms, especially self-attention, play an increasingly important role in deep feature representation for visual tasks. Self-attention updates the feature at each position by computing a weighted sum of features, using pairwise affinities across all positions to capture long-range dependencies within a single sample. However, self-attention has quadratic complexity and ignores potential correlations between different samples. This paper proposes a novel attention mechanism, called external attention, based on two external, small, learnable, shared memories; it can be implemented easily with just two cascaded linear layers and two normalization layers, and it conveniently replaces self-attention in existing popular architectures. External attention has linear complexity and implicitly considers correlations between all data samples. The authors further incorporate the multi-head mechanism into external attention to obtain an all-MLP architecture, external attention MLP (EAMLP), for image classification. Extensive experiments on image classification, object detection, semantic segmentation, instance segmentation, image generation, and point cloud analysis show that the method yields results comparable or superior to the self-attention mechanism and some of its variants, with much lower computational and memory costs.

Conclusion:

The paper introduces external attention, a novel yet effective attention mechanism applicable to a variety of visual tasks. The two external memory units it uses can be viewed as dictionaries for the whole dataset; they learn more representative features of the input while reducing computational cost. The authors hope external attention will inspire practical applications as well as research into its use in other domains such as NLP.


External attention architecture diagram:

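In brief, with input features F ∈ R^{N×d} (N = H×W flattened pixels) and two small learnable external memory units M_k, M_v ∈ R^{S×d} shared across the whole dataset, external attention computes (a sketch in the paper's notation; the normalization matches the official code below):

$A = \mathrm{Norm}(F M_k^{\top}) \in \mathbb{R}^{N \times S}, \qquad F_{\text{out}} = A\, M_v \in \mathbb{R}^{N \times d}$

where Norm is the double normalization: a softmax over the pixel dimension followed by an l1 normalization over the memory dimension,

$\tilde{\alpha} = F M_k^{\top}, \qquad \hat{\alpha}_{i,j} = \frac{\exp(\tilde{\alpha}_{i,j})}{\sum_{k} \exp(\tilde{\alpha}_{k,j})}, \qquad \alpha_{i,j} = \frac{\hat{\alpha}_{i,j}}{\sum_{k} \hat{\alpha}_{i,k}}.$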

The computational complexity of external attention is O(dSN); as d and S are hyper-parameters, the proposed algorithm is linear in the number of pixels. In fact, the authors find that a small S, e.g. 64, works well in experiments. Thus, external attention is much more efficient than self-attention, allowing its direct application to large-scale inputs.
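
As a rough back-of-the-envelope comparison (a minimal sketch with assumed example sizes, not figures from the paper):

# Rough comparison of the dominant terms in self-attention vs. external attention
# (illustrative only; sizes below are assumed example values).
d = 512          # feature dimension (assumed)
S = 64           # number of external memory elements, as suggested in the paper
H = W = 128      # spatial resolution of the feature map (assumed)
N = H * W        # number of pixels

self_attn_cost = d * N * N       # O(d N^2): pairwise similarities over all positions
external_attn_cost = d * S * N   # O(d S N): two small linear maps against the memories

print(f"self-attention   ~ {self_attn_cost:.3e} multiply-adds")
print(f"external attention ~ {external_attn_cost:.3e} multiply-adds")
print(f"ratio            ~ {self_attn_cost / external_attn_cost:.0f}x  (= N / S)")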

Usage: roughly at the last layer, e.g. attached to the backbone's final feature map; a usage sketch follows the official code below.


Code:


Official code:

# from: https://github.com/MenghaoGuo/EANet/blob/main/model_torch.py
# (imports added here for self-containment; in the original repo they live at the top of the file)
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

# assumption: norm_layer is a 2D batch norm; the original repo configures this elsewhere
norm_layer = nn.BatchNorm2d

class External_attention(nn.Module):
    '''
    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()

        self.conv1 = nn.Conv2d(c, c, 1)

        # the two external memory units M_k and M_v, shared across the whole dataset
        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)   # M_k: projects c -> k

        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)   # M_v: projects k -> c
        # initialize M_v as the transpose of M_k
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))

        # Kaiming-style initialization for conv layers, constant init for batch norm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        idn = x                                 # keep the identity for the residual connection
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h * w
        x = x.view(b, c, n)                     # b, c, n

        attn = self.linear_0(x)                 # b, k, n: attention map A = M_k * x
        attn = F.softmax(attn, dim=-1)          # softmax over the pixel dimension n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True))  # l1 norm over the memory dim (double normalization)
        x = self.linear_1(attn)                 # b, c, n: output = M_v * A

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn                             # residual connection
        x = F.relu(x)
        return x
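
A minimal usage sketch (assuming the imports and the norm_layer choice above; the input is a 4D feature map, e.g. a backbone's last-stage output, and the shapes here are illustrative):

# usage sketch with assumed shapes (not from the official repo)
import torch

ea = External_attention(c=512)          # c = channel number of the feature map
feat = torch.randn(2, 512, 32, 32)      # b, c, h, w
out = ea(feat)
print(out.shape)                        # torch.Size([2, 512, 32, 32])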

Implementing multi-head external attention:

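In the multi-head variant, the input tokens are split into H heads, external attention is applied to each head against shared memory units, and the head outputs are concatenated and linearly projected back (a sketch following the paper's formulation):

$h_i = \mathrm{ExternalAttention}(F_i, M_k, M_v), \qquad F_{\text{out}} = \mathrm{Concat}(h_1, \ldots, h_H)\, W_o$
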
Official code:

# from: https://github.com/MenghaoGuo/EANet/blob/main/multi_head_attention_torch.py
# (import added here for self-containment)
from torch import nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        # qkv_bias and qk_scale are kept for interface compatibility with standard
        # ViT attention blocks but are unused by external attention
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0
        self.coef = 4
        self.trans_dims = nn.Linear(dim, dim * self.coef)   # expand the channel dimension by coef
        self.num_heads = self.num_heads * self.coef
        self.k = 256 // self.coef                           # memory size per head
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)   # M_k
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)   # M_v

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim * self.coef, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        x = self.trans_dims(x)                                     # B, N, C*coef
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)   # B, heads, N, head_dim

        attn = self.linear_0(x)                                    # B, heads, N, k

        attn = attn.softmax(dim=-2)                                # softmax over the token dimension N
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))      # l1 norm over the memory dimension k
        attn = self.attn_drop(attn)
        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)  # B, N, C*coef

        x = self.proj(x)                                           # project back to C
        x = self.proj_drop(x)
        return x
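
A minimal usage sketch on a token sequence (assumed shapes; dim must be divisible by num_heads):

# usage sketch with assumed shapes (not from the official repo)
import torch

mha = Attention(dim=384, num_heads=8)
tokens = torch.randn(2, 196, 384)       # B, N, C  (e.g. 14x14 patch tokens)
out = mha(tokens)
print(out.shape)                        # torch.Size([2, 196, 384])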

Code from a repository that collects attention modules (unofficial):

# from: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/ExternalAttention.py

import torch
from torch import nn
from torch.nn import init


class ExternalAttention(nn.Module):

    def __init__(self, d_model, S=64):
        super().__init__()
        self.mk = nn.Linear(d_model, S, bias=False)   # external memory unit M_k
        self.mv = nn.Linear(S, d_model, bias=False)   # external memory unit M_v
        self.softmax = nn.Softmax(dim=1)              # softmax over the token dimension n
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)                             # bs, n, S
        attn = self.softmax(attn)                           # bs, n, S: softmax over n
        attn = attn / torch.sum(attn, dim=2, keepdim=True)  # bs, n, S: l1 norm over S (double normalization)
        out = self.mv(attn)                                 # bs, n, d_model

        return out


if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ea = ExternalAttention(d_model=512, S=8)
    output = ea(input)
    print(output.shape)

    