VOLO(TAPMI 2022, Sea)论文与代码解析

paper:VOLO: Vision Outlooker for Visual Recognition

official implementation:https://github.com/sail-sg/volo

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/volo.py

背景

视觉识别领域长期以来由卷积神经网络(CNNs)主导。尽管最近流行的视觉Transformer(ViTs)在ImageNet分类中展示了基于自注意力模型的巨大潜力,但如果没有额外数据,ViTs的性能仍然不如最新的SOTA(最先进)CNNs。

出发点

作者发现ViTs在ImageNet分类中的主要限制因素是其在将细粒度特征编码到token表示中的低效性。为了解决这个问题,本文引入了一种新的Outlook注意力机制,并提出了一种简单且通用的架构,称为Vision Outlooker(VOLO)。

创新点

  • Outlook注意力机制:与侧重于全局依赖建模的自注意力不同,Outlook注意力有效地将细粒度特征和上下文信息编码到tokens中。
  • 两阶段架构设计:VOLO首先对图像进行细粒度的tokenization,然后使用多个Outlookers在细粒度级别上编码token表示,接着在粗粒度级别上构建全局依赖关系。
  • 高效的注意力权重生成:Outlook注意力通过简单的线性投影生成注意力权重,避免了自注意力机制中昂贵的点积计算。

效果

  • ImageNet-1K分类:VOLO在ImageNet-1K分类上实现了87.1%的top-1准确率,这是首个在该基准上超过87%准确率的模型,且未使用任何额外的训练数据。
  • 迁移性能:预训练的VOLO在下游任务(如语义分割)中表现出色,在Cityscapes验证集上取得了84.3%的mIoU得分,在ADE20K验证集上取得了54.3%的mIoU得分。
  • 模型效率:与之前的SOTA模型相比,VOLO使用更少的训练参数但达到了更高的准确率。

方法介绍

Outlook attention的启发包括两点:1)每个空间位置的特征具有足够的代表性来生成注意力权重,在局部聚合其相邻特征。2)密集、局部的空间聚合可以高效地编码fine-level的信息。

对于每个空间位置 \((i,j)\),outlook attention计算它与以 \((i,j)\) 为中心大小为 \(K\times K\) 的局部窗口内所有相邻位置的相似性。和self-attention需要Query-Key矩阵相乘来计算注意力不同,outlook attention通过一个reshape操作简化了该过程。

给定输入 \(\mathbf{X}\),每个 \(C\) 维的token通过两个线性层 \(\mathbf{W}_A\in \mathbb{R}^{C\times K^4}\) 和 \(\mathbf{W}_V\in \mathbb{R}^{C\times C}\) 映射得到outlook权重 \(\mathbf{A}\in \mathbb{R}^{H\times W\times K^4}\) 和value表示 \(\mathbf{V}\in \mathbb{R}^{H\times W\times C}\)。我们用 \(\mathbf{V}_{\Delta i,j}\in \mathbb{R}^{C\times K^2}\) 表示以 \((i,j)\) 为中心的局部窗口内的所有值,即

位置 \((i,j)\) 处的outlook权重直接当做注意力权重用来聚合value,通过将其reshape成 \(\mathbf{\hat{A}}_{i,j}\in \mathbb{R}^{K^2\times K^2}\) 然后接一个Softmax函数。value映射过程表示如下

outlook attention密集地聚合映射后的value表示。将来自不同局部窗口在同一位置处的权重相加得到最终输出

outlook attention的整体过程如图2所示,Pytorch代码如下所示,式(3)和(5)分别对应Unfold和Fold操作。

Multi-head的实现也很简单,假设head数量为 \(N\),我们只需要调整 \(\mathbf{W}_A\) 的shape得到 \(\mathbf{W}_A\in \mathbb{R}^{C\times N\cdot K^4}\)。然后outlook权重和value embedding也均匀分成 \(N\) 份,得到 \(\mathbf{A}_n\in\mathbb{R}^{H\times W\times K^4}\) 和 \(\mathbf{V}_n\in\mathbb{R}^{H\times W\times C_N},\{n=1,2,..,N\}\),其中 \(C_N\) 是每个head的维度且满足 \(C_N\times N=C\)。对于每对 \((\mathbf{A}_n,\mathbf{V}_n)\) 分别计算outlook attention,然后将结果concat起来得到最终输出。

实验结果

作者基于LV-ViT(具体介绍见Token Labeling(NeurIPS 2021, ByteDance)论文解读-CSDN博客)和outlook attention构建了Vision Outlooker(VOLO),为了得到更细粒度的token表示,第一个stage中调整了patch embedding层使得patch大小为8x8而不是16x16,然后用多个outlooker来在更精细的level生成更expressive的token表示。stage2-stage4采用传统的transformer block,每个stage开始用一个patch embedding层来下采样降低分辨率。作者构建了多个不同尺度的VOLO variants,具体配置如下

在ImageNet数据集上和其它SOTA模型的性能对比如下表所示,可以看到对于不同的模型大小,VOLO都得到了最优的性能。其中在224x224的输入上训练的VOLO-D5取得了86.1%的top-1精度,进一步在512x512大小的输入上微调,精度达到了87.1%,这是第一个不需要额外训练数据在ImageNet上超过87%精度的模型。 

作者进一步在Cityscapes和ADE20K数据集上测试了VOLO作为预训练backbone的语义分割性能,结果如下,可以看到VOLO在两个数据集上都取得了新的SOTA表现。

代码解析

这里以timm中的实现为例,模型选择"volo_d1_224",输入大小为(1, 3, 224, 224)。OutlookAttention的代码如下,其中self.v就是 \(\mathbf{W}_V\)。self.unfold对应式(3)(unfold的具体用法可参考torch.nn.functional.unfold 用法解读-CSDN博客)。self.attn就是 \(\mathbf{W}_A\),50和53行对应式(4),最后F.fold操作对应式(5),F.fold就是unfold的逆过程,具体可参考「详解」torch.nn.Fold和torch.nn.Unfold操作_torch.unfold-CSDN博客,fold过程中重叠位置的值会相加,即上文提到的“将来自不同局部窗口在同一位置处的权重相加得到最终输出”。

class OutlookAttention(nn.Module):

    def __init__(
        self,
        dim,  # 192
        num_heads,  # 6
        kernel_size=3,
        padding=1,
        stride=1,
        qkv_bias=False,
        attn_drop=0.,
        proj_drop=0.,
    ):
        super().__init__()
        head_dim = dim // num_heads  # 32
        self.num_heads = num_heads  # 6
        self.kernel_size = kernel_size  # 3
        self.padding = padding  # 1
        self.stride = stride  # 2
        self.scale = head_dim ** -0.5

        self.v = nn.Linear(dim, dim, bias=qkv_bias)  # W_{V}
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)  # W_{A}, 192,3^4*6=486

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):  # (1,28,28,192)
        B, H, W, C = x.shape
        # Linear(in_features=192, out_features=192, bias=False)
        v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W; (1,192,28,28)

        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)  # 14,14
        v = self.unfold(v).reshape(
            B, self.num_heads, C // self.num_heads,
            self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2)  # B,H,N,kxk,C/H
        # (1,1728,196)->(1,6,32,9,196)->(1,6,196,9,32)

        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)  # (1,192,28,28)->(1,192,14,14)->(1,14,14,192)
        # Linear(in_features=192, out_features=486, bias=True)
        attn = self.attn(attn).reshape(
            B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
            self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B,H,N,kxk,kxk
        # (1,14,14,486)->(1,196,6,9,9)->(1,6,196,9,9)
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
        # (1,6,196,9,32)->(1,6,32,9,196)->(1,1728,196)
        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)  # (1,192,28,28)
        # Linear(in_features=192, out_features=192, bias=True)
        x = self.proj(x.permute(0, 2, 3, 1))  # (1,28,28,192)
        x = self.proj_drop(x)

        return x
  • 22
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值