T2T-ViT(ICCV 2021)论文与代码解析

paper:Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

official implementation:https://github.com/yitu-opensource/T2T-ViT

third-party implementation:https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/t2t_vit.py

存在的问题

Vision Transformer 需要大量的数据集(比如JFT-300M)才能得到超越CNN的效果,当在中等规模的数据集(比如ImageNet)上从零开始训练时性能不如CNN。 作者发现,这是因为ViT在训练时存在两个主要问题:简单的图像切分无法有效捕捉局部结构(如边缘和线条),以及ViT的冗余注意力机制设计导致特征丰富性不足,特别是在计算预算和训练样本有限的情况下。

创新点

本文解决了ViT在中等规模数据集上训练时性能不如CNN的问题。为此,作者提出了一种新的Tokens-to-Token Vision Transformer (T2T-ViT) 模型,通过改进图像切分和注意力机制设计来提升模型的训练效率和性能。具体包括:

  1. 逐层Tokens-to-Token(T2T)转换:通过递归地将相邻的Tokens聚合成一个Token,逐步结构化图像,以便更好地捕捉局部结构信息并减少Tokens长度。
  2. 高效的骨干设计:借鉴CNN架构设计,采用“深而窄”的结构,提高特征丰富性并减少冗余。

效果

T2T-ViT 在 ImageNet 数据集上从头开始训练时,与具有相似参数量的 ResNet 和 MobileNet 相比,取得了更好的性能。例如,与 ResNet50(25.5M 参数)相比,T2T-ViT(21.5M 参数)在 384×384 图像分辨率下达到了 83.3% 的 top-1 准确率,而 ResNet50 的准确率在 76%-79% 之间。此外,T2T-ViT 还成功地将 ViT 的参数数量和MACs(Multiply-Accumulate Operations)减少了一半,同时实现了超过 3.0% 的性能提升。

方法介绍

为了克服ViT简单的tokenization的局限性以及低效的backbone架构,作者提出了Tokens-to-Token Vision Transformer(T2T-ViT)。如图4所示,T2T-ViT有两个主要的组成部分:1)一个逐层的“Tokens-to-Token module(T2T module)”用来对图像的局部结构信息进行建模,并逐步减少tokens的长度。2)一个高效的“T2T-ViT backbone”从T2T module中提取tokens的全局注意力关系。在探索了几种不同的CNN架构后,采用了一种深而窄deep-narrow的骨干结构来减少冗余,提高特征丰富性。

Tokens-to-Token: Progressive Tokenization

每个T2T过程包括两步:Re-structurization Soft Split,如图3所示

Re-structurization 如图3所示,给定一个来自前一个transformer layer的tokens序列 \(T\),它将通过self-attention block(图3中的T2T transformer)进行转换:

其中 MSA 表示多头自注意力操作,MLP表示多层感知器,两者都采用layer normalization。然后tokens \(T'\) 被reshape成空间维度中的图像

其中“Reshape”将tokens \(T'\in\mathbb{R}^{l\times c}\) 转换成 \(I\in\mathbb{R}^{h\times w\times c}\),其中 \(l\) 是tokens \(T'\) 的长度,\(h,w,c\) 分别表示高、宽和通道数且 \(l=h\times w\)。

Soft Split 如图3所示,在得到重构后的图像 \(I\) 后,我们对其进行 soft split 来对局部结构信息进行建模并减小tokens的长度。具体来说,为了避免从重构后的图像中生成tokens时丢失信息,我们将其split成重叠的patch。这样每个patch都和周围的patch相关联,从而建立了一个先验prior:相邻的tokens之间应该具有强相关性。然后每个patch的token拼接到一起得到一个token,从而可以从周围的像素和patch中聚合局部信息。

当进行软分割时,每个patch的大小为 \(k\times k\),patch之间的重叠为 \(s\),图片的padding为 \(p\),其中 \(k-s\) 和卷积中的stride类似。所以对于重构的图片 \(I\in\mathbb{R}^{h\times w\times c}\),soft split后的输出token \(T_o\) 的长度为

每个split patch的大小为 \(k\times k\times c\),我们将所有的patch沿空间维度展平得到 \(T_o\in\mathbb{R}^{l_o\times ck^2}\)。在soft split后,输出进入到下一个T2T。

T2T module 通过多次迭代上述Re-structurization和Soft Split,T2T module可以逐步减少token的长度,变换图像的空间结构。T2T module的迭代过程如下

对于输入图片 \(I_0\),我们首先通过一个soft split得到tokens:\(T_1=SS(I_0)\)。在迭代多次后最后一个iteration,T2T module的输出token \(T_f\) 具有一个固定长度,所以T2T-ViT的backbone可以通过 \(T_f\) 建模全局关系。

此外,由于T2T module中的token长度大于ViT中的一般情况(16x16),因此MACs和内存开销都更大。为了解决这个问题,作者将T2T module中T2T layer的通道数设置的较小(32或64)来减小MACs。

T2T-ViT Backbone

作者探索了CNN中五种不同的结构设计,具体包括:

  1. DenseNet中的Dense connection
  2. Deep-narrow结构和Wide-ResNets中的shallow-wide结构
  3. 通道注意力比如SE
  4. 多头注意力中分配更多的head,像ResNeXt一样
  5. GhostNet中的Ghost operation

通过实验,作者观察到:1)在ViT中采用deep-narrow结构减少了通道冗余,增强深度提高了特征丰富性,降低模型尺寸和MACs的同时模型性能还提高了。2)SE block对ViT的性能提升也有帮助,但不如deep-narrow大。基础这些发现,作者将T2T-ViT设计成deep-narrow的结构,具体来说,它有一个较小的通道数和隐藏层的维度 \(d\),但有更多的层 \(b\)。对于T2T module最后一层的输出具有固定长度的token \(T_f\),我们将其与一个class token拼接到一起,并加上位置编码,就像ViT中做的一样

其中 \(E\) 是正弦位置编码,\(LN\) 是layer normalization,\(fc\) 是全连接层,\(y\) 是预测输出。

T2T-ViT Architecture

如图4所示,T2T-ViT由两部分组成,T2T module和T2T-ViT backbone。对于T2T模块有多种设计选择,如图4所示这里我们选择 \(n=2\),这意味着T2T module中有 \(n+1=3\) 个soft split和 \(n=2\) 个re-structurization。三个soft split中的patch size为 \(P=\{7,3,3\}\),重叠大小 \(S=[3,1,1]\),根据式(3)这将输入图片从224x224减小到14x14。

为了和CNN公平比较,我们让T2T-ViT的大小与ResNets和MobileNets相当,具体来说,我们设计了三种模型:T2T-ViT-14、T2T-ViT-19、T2T-ViT-24,参数量分别和ResNet-50、ResNet-101、ResNet-152相当。此外还设计了两个轻量模型:T2T-ViT-7和T2T-ViT-12,参数量分别和MobileNetV1和MobileNetV2相当。具体的参数配置如表1所示

实验结果

在ImageNet上和ViT以及DeiT的结果对比如表2所示,可以看到在大约相同的参数量下,T2T-ViT的精度超过了ViT和DeiT。

和ResNet系列的对比如表3,T2T-ViT的精度也更高。

代码解析

这里的代码是MMPretrain中的实现。其中soft split通过nn.Unfold实现,具体介绍见torch.nn.functional.unfold 用法解读-CSDN博客。一共包含3次soft split和2次re-structurization,输入大小为(1, 3, 224, 224),经过第一个soft split和re-structurization后得到输出(1, 64, 56, 56),经过第二次soft split和re-structurization后得到输出(1, 64, 28, 28),经过第三次soft split后得到输出(1, 196, 576),其中196=14x14,576=3x3x64,注意这次没有re-structurization了,然后就进入了self.project。

class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int): Input image size
        in_channels (int): Number of input channels
        embed_dims (int): Embedding dimension
        token_dims (int): Tokens dimension in T2TModuleAttention.
        use_performer (bool): If True, use Performer version self-attention to
            adopt regular self-attention. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.

    Notes:
        Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
        MACs
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        embed_dims=384,
        token_dims=64,
        use_performer=False,
        init_cfg=None,
    ):
        super(T2TModule, self).__init__(init_cfg)

        self.embed_dims = embed_dims

        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        if not use_performer:
            self.attention1 = T2TTransformerLayer(
                input_dims=in_channels * 7 * 7,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.attention2 = T2TTransformerLayer(
                input_dims=token_dims * 3 * 3,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
        else:
            raise NotImplementedError("Performer hasn't been implemented.")

        # there are 3 soft split, stride are 4,2,2 separately
        out_side = img_size // (4 * 2 * 2)
        self.init_out_size = [out_side, out_side]
        self.num_patches = out_side**2

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        h, w = input_size  # 224,224
        kernel_size = to_2tuple(unfold.kernel_size)  # (7,7)
        stride = to_2tuple(unfold.stride)  # (4,4)
        padding = to_2tuple(unfold.padding)  # (2,2)
        dilation = to_2tuple(unfold.dilation)  # (1,1)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        # (224 + 2x2 - 1x(7-1) - 1) // 4 + 1 = 56
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):  # (1,3,224,224)
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])  # (56,56)
        x = self.soft_split0(x).transpose(1, 2)  # (1,147,3136)->(1,3136,147), 3136=56x56

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)  # (1,3136,147)->(1,3136,64)->(1,64,3136)
            B, C, _ = x.shape
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])  # (1,64,56,56)

            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)  # (28,28)
            x = soft_split(x).transpose(1, 2)  # (1,576,784)->(1,784,576), 576=3x3x64

        # final tokens
        x = self.project(x)  # (1,196,576)->(1,196,384)
        return x, hw_shape  # _, (14,14)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值