DeepViT: Towards Deeper Vision Transformer - RE-ATTENTION

DeepViT: Towards Deeper Vision Transformer
code

Paper Contributions

• We dig deep into the behavior of vision transformers and observe that, unlike CNNs, they cannot keep benefiting from stacking more layers. We further identify the underlying reason behind this counter-intuitive phenomenon and, for the first time, attribute it to attention collapse.

• We propose Re-Attention, a simple yet effective attention mechanism that takes the exchange of information among different attention heads into account. The Re-Attention map is computed as follows (a minimal numerical sketch is given right after this contribution list):
$$\mathrm{Re\text{-}Attention}(Q,K,V)=\mathrm{Norm}\!\left(\Theta^{\top}\,\mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)\right)V$$

• To the best of our knowledge, we are the first to successfully train a ViT with 32 transformer blocks on ImageNet-1k.
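To make the Re-Attention formula concrete, below is a minimal PyTorch sketch on random tensors. The shapes (B batch, H heads, N tokens, d per-head dimension) and the stand-in matrix theta are illustrative assumptions only; in the official code, Θ is realized as a 1x1 convolution over the head dimension followed by BatchNorm, as shown later in this post.

import torch
import torch.nn as nn

B, H, N, d = 2, 12, 197, 32                       # assumed batch, heads, tokens, per-head dim
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

# Standard self-attention map: Softmax(QK^T / sqrt(d)), shape [B, H, N, N]
attn = (q @ k.transpose(-2, -1) / d ** 0.5).softmax(dim=-1)

# Stand-in for the learnable H x H head-mixing matrix Theta
theta = torch.eye(H) + 0.01 * torch.randn(H, H)
mixed = torch.einsum('gh,bhnm->bgnm', theta.t(), attn)   # Theta^T applied across the head dimension

# Norm(.) is taken here as BatchNorm2d over the head dimension, matching the ReAttention code below
re_attn = nn.BatchNorm2d(H)(mixed)

out = re_attn @ v                                 # [B, H, N, d]; heads are then merged and projected as usual
print(out.shape)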

Attention Collapse

Motivated by the success of deep CNNs, we conduct a systematic study on how the performance of ViTs changes as depth increases. Without loss of generality, we first fix the hidden dimension and the number of heads to 384 and 12 respectively, following the common practice in [37]. We then stack different numbers of transformer blocks (ranging from 12 to 32) to build several ViT models of different depths. The overall image classification performance is evaluated on ImageNet [18] and summarized in Figure 1. As the performance curves show, we are surprised to find that, as the model goes deeper, the classification accuracy improves slowly and saturates quickly. More specifically, we observe that the improvement stops after 24 transformer blocks are used. This phenomenon shows that existing ViTs have difficulty benefiting from deeper architectures.
Such a problem is counter-intuitive and worth exploring, since a similar issue (i.e., how to effectively train deeper models) was also observed for CNNs in their early development but was later properly solved. Looking into the transformer architecture, we want to highlight that the self-attention mechanism plays a key role in ViTs and distinguishes them significantly from CNNs. We therefore start from studying self-attention, or more specifically, how the generated attention map A changes as the model goes deeper.
To measure how the attention maps evolve across layers, we compute the following cross-layer similarity between attention maps from different layers:

$$M_{h,t}^{p,q}=\frac{\left(A_{h,:,t}^{p}\right)^{\top} A_{h,:,t}^{q}}{\left\lVert A_{h,:,t}^{p}\right\rVert \,\left\lVert A_{h,:,t}^{q}\right\rVert} \tag{2}$$

where $M^{p,q}$ is the cosine similarity matrix between the attention maps of layers $p$ and $q$. Each element $M_{h,t}^{p,q}$ measures the attention similarity for head $h$ and token $t$. Consider a specific self-attention layer and its $h$-th head: $A^{*}_{h,:,t}$ is a $T$-dimensional vector describing how much input token $t$ contributes to each of the $T$ output tokens. Therefore, $M_{h,t}^{p,q}$ provides a proper measure of how a token's contribution changes from layer $p$ to layer $q$. When $M_{h,t}^{p,q}$ equals 1, token $t$ plays exactly the same role in the self-attention of layers $p$ and $q$.
Given Eq. (2), we then train a ViT model with 32 transformer blocks on ImageNet-1k and study the above similarity between all of its attention maps.
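As a concrete illustration of Eq. (2), the sketch below computes $M^{p,q}$ for two attention maps with PyTorch's cosine_similarity. The tensors A_p and A_q are random stand-ins (shape [H, T, T]) for the attention maps of layers p and q; in practice they would be collected from a trained model.

import torch
import torch.nn.functional as F

H, T = 12, 197                                   # assumed number of heads and tokens
A_p = torch.rand(H, T, T).softmax(dim=-1)        # stand-in attention map of layer p
A_q = torch.rand(H, T, T).softmax(dim=-1)        # stand-in attention map of layer q

# A_p[h, :, t] is the T-dimensional column describing how input token t contributes to every output token.
# M[h, t] is the cosine similarity between that column in layer p and the same column in layer q.
M = F.cosine_similarity(A_p, A_q, dim=1)         # shape [H, T]
print(M.shape, M.mean().item())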
As shown in Figure 3(a), after the 17th block the proportion of similar attention maps, measured by M, is larger than 90%. This indicates that the attention maps learned afterwards are similar, and the transformer blocks may degenerate into an MLP.
Therefore, further stacking such degenerate MHSA layers may introduce a model rank degeneration problem (i.e., the rank of the model parameter tensor produced by multiplying the layer-wise parameters decreases) and limit the model's learning capacity. Our analysis of the degeneration of the learned features below also verifies this. The observed attention collapse may be one of the reasons for the performance saturation of ViTs. To further verify whether this phenomenon exists in ViTs of different depths, we conduct the same experiment on ViTs with 12, 16, 24 and 32 transformer blocks and count the number of blocks with similar attention maps. The results in Figure 3(b) clearly show that the ratio of blocks with similar attention maps to the total number of blocks increases as more transformer blocks are added.
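The block-counting statistic behind Figure 3(b) can be reproduced in spirit as follows: for every pair of blocks, compute M from Eq. (2) and call the pair "similar" when a large fraction of its entries exceeds a threshold. The 0.9 threshold and the 90% fraction below are assumptions for illustration; the list attn_maps would be collected from a trained model, e.g. with forward hooks.

import torch
import torch.nn.functional as F

def similar_pair_ratio(attn_maps, thresh=0.9, frac=0.9):
    """attn_maps: list of per-block attention maps, each of shape [H, T, T]."""
    similar, total = 0, 0
    for p in range(len(attn_maps)):
        for q in range(p + 1, len(attn_maps)):
            M = F.cosine_similarity(attn_maps[p], attn_maps[q], dim=1)   # [H, T]
            similar += int((M > thresh).float().mean() > frac)           # pair counted as "similar"
            total += 1
    return similar / total

# Toy example with random maps; real maps come from a trained ViT.
maps = [torch.rand(12, 197, 197).softmax(dim=-1) for _ in range(8)]
print(similar_pair_ratio(maps))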
To understand how attention collapse affects the performance of ViT models, we further study how it affects the learning of deeper features. For the 32-block ViT model, we compare the final output feature with the output of each intermediate transformer block by examining their cosine similarity.
The results in Figure 4 show that the similarity is very high and that the learned features stop evolving after the 20th block. The increase in attention similarity correlates closely with the increase in feature similarity. This observation suggests that attention collapse is responsible for the non-scalable behavior of ViTs.
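The feature-similarity measurement of Figure 4 boils down to a cosine similarity between each block's output and the final output. A minimal sketch, assuming block_outputs is a list of [B, N, C] hidden states collected from a trained model (e.g. via forward hooks):

import torch
import torch.nn.functional as F

def feature_similarity(block_outputs):
    """Cosine similarity of every intermediate block output with the final block output."""
    final = block_outputs[-1]                                    # [B, N, C]
    return [F.cosine_similarity(h, final, dim=-1).mean().item()  # averaged over batch and tokens
            for h in block_outputs]

# Toy example: 32 random hidden states standing in for the 32-block ViT.
outputs = [torch.randn(2, 197, 384) for _ in range(32)]
print(feature_similarity(outputs)[:5])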

Model Comparison (replacing the original self-attention with Re-Attention)


Re-Attention Module

Figure 7: (Left) the original self-attention mechanism; (Right) the proposed Re-Attention mechanism.
As shown in the figure, the original attention maps are mixed by a learnable matrix Θ before being multiplied with the values (in the code: self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)); a small sanity check of this 1x1-convolution trick is given below.
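The 1x1 convolution can be verified in isolation: applied to an attention tensor of shape [B, num_heads, N, N], nn.Conv2d(num_heads, num_heads, 1) mixes the per-head maps with an H x H weight, i.e. it plays exactly the role of Θ (up to a transpose). A small sanity check with assumed shapes:

import torch
import torch.nn as nn

B, H, N = 2, 12, 197
attn = torch.rand(B, H, N, N).softmax(dim=-1)

mix = nn.Conv2d(H, H, kernel_size=1)                   # learnable Theta as a 1x1 conv over heads
out_conv = mix(attn)

# Equivalent explicit head-mixing with the same weights
theta = mix.weight.view(H, H)                          # [out_heads, in_heads]
out_mm = torch.einsum('gh,bhnm->bgnm', theta, attn) + mix.bias.view(1, H, 1, 1)

print(torch.allclose(out_conv, out_mm, atol=1e-5))     # True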

Visualization of the Effect


Figure 6: Visualization of the attention maps of selected blocks of the baseline ViT model with 32 transformer blocks.
The first row is based on the original self-attention module, the second row on Re-Attention. It can be seen that the model only learns local patch relations in its shallow blocks, while the remaining attention values are close to zero. Although the attention span gradually enlarges as the blocks go deeper, the attention maps tend to become uniform and thus lose diversity. After adding Re-Attention, the previously similar attention maps become diverse, as shown in the second row; a nearly uniform attention map is only learned in the last block.

ReAttention Code

import torch
import torch.nn as nn


class ReAttention(nn.Module):
    """
    It is observed that similarity along same batch of data is extremely large. 
    Thus can reduce the bs dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,expansion_ratio = 3, apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            # Learnable head-mixing matrix Theta, implemented as a 1x1 convolution over the head dimension
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            # The Norm(.) of the Re-Attention formula, applied per head
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
            self.reatten_scale = self.scale if transform_scale else 1.0
        else:
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, atten=None):
        B, N, C = x.shape
        # x = self.fc(x)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            # Re-Attention: mix the per-head attention maps with Theta, then normalize (BatchNorm over heads)
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next
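
For reference, a quick smoke test of the class above; the dim/num_heads values match the 384-dimension, 12-head setting mentioned earlier and are otherwise arbitrary:

import torch

# Assumes the ReAttention class defined above is in scope.
module = ReAttention(dim=384, num_heads=12, qkv_bias=True)
tokens = torch.randn(2, 197, 384)            # [batch, num_patches + cls token, dim]
out, attn = module(tokens)
print(out.shape, attn.shape)                 # torch.Size([2, 197, 384]) torch.Size([2, 12, 197, 197])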

Full Code from the ReAttention Page

import torch
import torch.nn as nn
import numpy as np
from functools import partial
import torch.nn.init as init
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., expansion_ratio=3):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.expansion = expansion_ratio
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * self.expansion, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
class ReAttention(nn.Module):
    """
    It is observed that similarity along same batch of data is extremely large. 
    Thus can reduce the bs dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,expansion_ratio = 3, apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            self.reatten_matrix = nn.Conv2d(self.num_heads,self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
            self.reatten_scale = self.scale if transform_scale else 1.0
        else:
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, atten=None):
        B, N, C = x.shape
        # x = self.fc(x)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, expansion=3, 
                 group = False, share = False, re_atten=False, bs=False, apply_transform=False,
                 scale_adjustment=1.0, transform_scale=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.re_atten = re_atten

        self.adjust_ratio = scale_adjustment
        self.dim = dim
        if  self.re_atten:
            self.attn = ReAttention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 
                expansion_ratio = expansion, apply_transform=apply_transform, transform_scale=transform_scale)
        else:
            self.attn = Attention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 
                expansion_ratio = expansion)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    def forward(self, x, atten=None):
        if self.re_atten:
            x_new, atten = self.attn(self.norm1(x * self.adjust_ratio), atten)
            x = x + self.drop_path(x_new/self.adjust_ratio)
            x = x + self.drop_path(self.mlp(self.norm2(x * self.adjust_ratio))) / self.adjust_ratio
            return x, atten
        else:
            x_new, atten = self.attn(self.norm1(x), atten)
            x= x + self.drop_path(x_new)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x, atten

class PatchEmbed_CNN(nn.Module):
    """ 
        Following T2T, we use 3 layers of CNN for comparison with other methods.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,spp=32):
        super().__init__()

        new_patch_size = to_2tuple(patch_size // 2)

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 112x112
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 112x112
        self.bn2 = nn.BatchNorm2d(64)

        self.proj = nn.Conv2d(64, embed_dim, kernel_size=new_patch_size, stride=new_patch_size)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, C]

        return x
class PatchEmbed(nn.Module):
    """ 
        Same embedding as timm lib.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class HybridEmbed(nn.Module):
    """ 
        Same embedding as timm lib.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
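
Putting the pieces together, the sketch below stacks a few of the Block modules defined above with re_atten=True on top of PatchEmbed, which is roughly how a deeper encoder with Re-Attention is assembled. It is only a minimal sketch, not the full DeepViT model (no class token, positional embedding or classification head), and the depth/width values are assumptions.

if __name__ == "__main__":
    # Assumes the imports and classes from the listing above are in scope.
    depth, dim, heads = 16, 384, 12
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=dim)
    blocks = nn.ModuleList([
        Block(dim=dim, num_heads=heads, mlp_ratio=3.,
              re_atten=True, apply_transform=True)      # Re-Attention in every block
        for _ in range(depth)
    ])

    x = embed(torch.randn(2, 3, 224, 224))              # [2, 196, 384]
    attn = None
    for blk in blocks:
        x, attn = blk(x, attn)                          # the attention map is threaded through the blocks
    print(x.shape, attn.shape)                          # [2, 196, 384], [2, 12, 196, 196]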