Day 3 second: Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks

总结

  • self-attention 的计算复杂度会呈二次方增长,因此实用性不大
  • 本文基于两个外部的、小尺寸的、可学习的,且共享的存储,提出 external attention,可以方便地在现有的流行模型中,替换掉self-attention结构。
  • external attention 有线性的计算复杂度,同时它考虑到了 所有样本 之间的相关性

Core Ideas and Contribution

  • 首先,通过计算 self query vector 和 external learnable key memory 之间的近似性得到注意力图,再将得到的注意力图和另一个 external learnable value memory 相乘,得到一个精炼过的特征图。
  • 这两个memory都是用线性层(linear layer)实现的。
  • 它们和单个的样本之间互相独立,但是在整个数据集之间共享。
  • 本结构能做到轻量化的最主要原因,是因为在这两个 external memory 里的参数要远小于输入特征中的参数。
  • external memory 设计成了用来学习整个数据集中最有区分度的特征,捕捉最有信息量的部分,以及排除掉其它样本中具有干涉/迷惑性的信息。

Methods and Approaches

目前常用的 self-attention 模块

先来看一下普通的 self-attention 模块是怎样操作的:

给到一个尺寸为 F ∈ R N × d F \in \mathbb{R}^{N \times d} FRN×d的输入,此处N为像素的数量,d为特征维度(feature dimensions)的数量,普通的self-attention首先将会把输入线性地映射到一个query矩阵 Q ∈ R N × d ′ Q \in \mathbb{R}^{N \times d^{\prime}} QRN×d,一个key矩阵 K ∈ R N × d ′ K \in \mathbb{R}^{N \times d^{\prime}} KRN×d,和一个value矩阵 V ∈ R N × d V \in \mathbb{R}^{N \times d} VRN×d。接下来,用以下的式子得到最终的结果:

A = ( α ) i , j = softmax ⁡ ( Q K T ) F out  = A V \begin{aligned}A &=(\alpha)_{i, j}=\operatorname{softmax}\left(Q K^{T}\right) \\F_{\text {out }} &=A V\end{aligned} AFout =(α)i,j=softmax(QKT)=AV

上式中的 A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N即为注意力矩阵,矩阵中的 α i , j \alpha_{i, j} αi,j是第i个像素点和第j个像素点之间的相似度

由此看来,之前的工作基本都是用的图片块(image patch),而不是所有的像素,是因为不这样的话对计算力的要求实在太大

当把注意力图给可视化出来后,注意到大部分的像素其实只与很少的几个像素之间有很强的相关性,因此一个 N x N 的注意力矩阵实在太过于冗余了
在这里插入图片描述

External Attention

于是,作者提出以下的方法,用external memory来计算注意力矩阵,计算得到的是输入像素与这个external memory 之间的attention。external memory 的尺寸是 M ∈ R S × d M \in \mathbb{R}^{S \times d} MRS×d。此处 ( α ) i , j (\alpha)_{i, j} (α)i,j代表的是第 i i i个像素和记性模块 M M M j j j行的相似度, M M M是个可学习的参数,且与输入独立(不相关),充当的是整个数据集的记忆模块.

A = ( α ) i , j = Norm ⁡ ( F M T ) F out  = A M \begin{aligned}A &=(\alpha)_{i, j}=\operatorname{Norm}\left(F M^{T}\right) \\F_{\text {out }} &=A M\end{aligned} AFout =(α)i,j=Norm(FMT)=AM

实际使用中,我们用的是两个不同的memory模块,称为 M k M_k Mk M v M_v Mv,前者是key,后者是value,以达到提升网络能力的目的,计算如下:

A = Norm ⁡ ( F M k T ) F out  = A M v \begin{aligned}A &=\operatorname{Norm}\left(F M_{k}^{T}\right) \\F_{\text {out }} &=A M_{v}\end{aligned} AFout =Norm(FMkT)=AMv

这样,我们的算法就是和像素的数量呈线性相关的了,复杂度为 O ( d S N ) \mathcal{O}(d S N) O(dSN),其中 d d d S S S为超参。而且试验中发现,即使S的值很少,比如设为64,也有很好的效果。

Python pesudo-code for external attention

# Input: F, an array with shape [B, N, C] (batch size, pixels, channels)
# Parameter: M_k, a linear layer without bias
# Parameter: M_v, a linear layer without bias
# Output: out, an array with shape [B, N, C]
attn = M_k(F) # shape=(B, N, M)
attn = softmax(attn, dim=1)
attn = l1_norm(attn, dim=2)
out = M_v(attn) # shape=(B, N, C)

Normalization

  • The attention calculated by matrix multiplication is sensitive to the scale of input features, thus need to be normalized, we use double-normalization here, which seperately normalize columns and rows.
  • Softmax is used here.

( α ~ ) i , j = F M k T α i , j = exp ⁡ ( α ~ i , j ) ∑ k exp ⁡ ( α ~ k , j ) α i , j = α i , j ∑ k α i ^ , k \begin{aligned}(\tilde{\alpha})_{i, j} &=F M_{k}^{T} \\\alpha_{i, j} &=\frac{\exp \left(\tilde{\alpha}_{i, j}\right)}{\sum_{k} \exp \left(\tilde{\alpha}_{k, j}\right)} \\\alpha_{i, j} &=\frac{\alpha_{i, j}}{\sum_{k} \alpha_{\hat{i}, k}}\end{aligned} (α~)i,jαi,jαi,j=FMkT=kexp(α~k,j)exp(α~i,j)=kαi^,kαi,j

Official Code

https://github.com/MenghaoGuo/-EANet

class External_attention(nn.Module):
    '''
    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()
        
        self.conv1 = nn.Conv2d(c, c, 1)

        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))        
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
 

    def forward(self, x):
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w
        x = x.view(b, c, h*w)   # b * c * n 

        attn = self.linear_0(x) # b, k, n
        attn = F.softmax(attn, dim=-1) # b, k, n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x)
        return x
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值