paper:Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
official implementation:https://github.com/yitu-opensource/T2T-ViT
third-party implementation:https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/t2t_vit.py
存在的问题
Vision Transformer 需要大量的数据集(比如JFT-300M)才能得到超越CNN的效果,当在中等规模的数据集(比如ImageNet)上从零开始训练时性能不如CNN。 作者发现,这是因为ViT在训练时存在两个主要问题:简单的图像切分无法有效捕捉局部结构(如边缘和线条),以及ViT的冗余注意力机制设计导致特征丰富性不足,特别是在计算预算和训练样本有限的情况下。
创新点
本文解决了ViT在中等规模数据集上训练时性能不如CNN的问题。为此,作者提出了一种新的Tokens-to-Token Vision Transformer (T2T-ViT) 模型,通过改进图像切分和注意力机制设计来提升模型的训练效率和性能。具体包括:
- 逐层Tokens-to-Token(T2T)转换:通过递归地将相邻的Tokens聚合成一个Token,逐步结构化图像,以便更好地捕捉局部结构信息并减少Tokens长度。
- 高效的骨干设计:借鉴CNN架构设计,采用“深而窄”的结构,提高特征丰富性并减少冗余。
效果
T2T-ViT 在 ImageNet 数据集上从头开始训练时,与具有相似参数量的 ResNet 和 MobileNet 相比,取得了更好的性能。例如,与 ResNet50(25.5M 参数)相比,T2T-ViT(21.5M 参数)在 384×384 图像分辨率下达到了 83.3% 的 top-1 准确率,而 ResNet50 的准确率在 76%-79% 之间。此外,T2T-ViT 还成功地将 ViT 的参数数量和MACs(Multiply-Accumulate Operations)减少了一半,同时实现了超过 3.0% 的性能提升。
方法介绍
为了克服ViT简单的tokenization的局限性以及低效的backbone架构,作者提出了Tokens-to-Token Vision Transformer(T2T-ViT)。如图4所示,T2T-ViT有两个主要的组成部分:1)一个逐层的“Tokens-to-Token module(T2T module)”用来对图像的局部结构信息进行建模,并逐步减少tokens的长度。2)一个高效的“T2T-ViT backbone”从T2T module中提取tokens的全局注意力关系。在探索了几种不同的CNN架构后,采用了一种深而窄deep-narrow的骨干结构来减少冗余,提高特征丰富性。
Tokens-to-Token: Progressive Tokenization
每个T2T过程包括两步:Re-structurization 和 Soft Split,如图3所示
Re-structurization 如图3所示,给定一个来自前一个transformer layer的tokens序列 \(T\),它将通过self-attention block(图3中的T2T transformer)进行转换:
其中 MSA 表示多头自注意力操作,MLP表示多层感知器,两者都采用layer normalization。然后tokens \(T'\) 被reshape成空间维度中的图像
其中“Reshape”将tokens \(T'\in\mathbb{R}^{l\times c}\) 转换成 \(I\in\mathbb{R}^{h\times w\times c}\),其中 \(l\) 是tokens \(T'\) 的长度,\(h,w,c\) 分别表示高、宽和通道数且 \(l=h\times w\)。
Soft Split 如图3所示,在得到重构后的图像 \(I\) 后,我们对其进行 soft split 来对局部结构信息进行建模并减小tokens的长度。具体来说,为了避免从重构后的图像中生成tokens时丢失信息,我们将其split成重叠的patch。这样每个patch都和周围的patch相关联,从而建立了一个先验prior:相邻的tokens之间应该具有强相关性。然后每个patch的token拼接到一起得到一个token,从而可以从周围的像素和patch中聚合局部信息。
当进行软分割时,每个patch的大小为 \(k\times k\),patch之间的重叠为 \(s\),图片的padding为 \(p\),其中 \(k-s\) 和卷积中的stride类似。所以对于重构的图片 \(I\in\mathbb{R}^{h\times w\times c}\),soft split后的输出token \(T_o\) 的长度为
每个split patch的大小为 \(k\times k\times c\),我们将所有的patch沿空间维度展平得到 \(T_o\in\mathbb{R}^{l_o\times ck^2}\)。在soft split后,输出进入到下一个T2T。
T2T module 通过多次迭代上述Re-structurization和Soft Split,T2T module可以逐步减少token的长度,变换图像的空间结构。T2T module的迭代过程如下
对于输入图片 \(I_0\),我们首先通过一个soft split得到tokens:\(T_1=SS(I_0)\)。在迭代多次后最后一个iteration,T2T module的输出token \(T_f\) 具有一个固定长度,所以T2T-ViT的backbone可以通过 \(T_f\) 建模全局关系。
此外,由于T2T module中的token长度大于ViT中的一般情况(16x16),因此MACs和内存开销都更大。为了解决这个问题,作者将T2T module中T2T layer的通道数设置的较小(32或64)来减小MACs。
T2T-ViT Backbone
作者探索了CNN中五种不同的结构设计,具体包括:
- DenseNet中的Dense connection
- Deep-narrow结构和Wide-ResNets中的shallow-wide结构
- 通道注意力比如SE
- 多头注意力中分配更多的head,像ResNeXt一样
- GhostNet中的Ghost operation
通过实验,作者观察到:1)在ViT中采用deep-narrow结构减少了通道冗余,增强深度提高了特征丰富性,降低模型尺寸和MACs的同时模型性能还提高了。2)SE block对ViT的性能提升也有帮助,但不如deep-narrow大。基础这些发现,作者将T2T-ViT设计成deep-narrow的结构,具体来说,它有一个较小的通道数和隐藏层的维度 \(d\),但有更多的层 \(b\)。对于T2T module最后一层的输出具有固定长度的token \(T_f\),我们将其与一个class token拼接到一起,并加上位置编码,就像ViT中做的一样
其中 \(E\) 是正弦位置编码,\(LN\) 是layer normalization,\(fc\) 是全连接层,\(y\) 是预测输出。
T2T-ViT Architecture
如图4所示,T2T-ViT由两部分组成,T2T module和T2T-ViT backbone。对于T2T模块有多种设计选择,如图4所示这里我们选择 \(n=2\),这意味着T2T module中有 \(n+1=3\) 个soft split和 \(n=2\) 个re-structurization。三个soft split中的patch size为 \(P=\{7,3,3\}\),重叠大小 \(S=[3,1,1]\),根据式(3)这将输入图片从224x224减小到14x14。
为了和CNN公平比较,我们让T2T-ViT的大小与ResNets和MobileNets相当,具体来说,我们设计了三种模型:T2T-ViT-14、T2T-ViT-19、T2T-ViT-24,参数量分别和ResNet-50、ResNet-101、ResNet-152相当。此外还设计了两个轻量模型:T2T-ViT-7和T2T-ViT-12,参数量分别和MobileNetV1和MobileNetV2相当。具体的参数配置如表1所示
实验结果
在ImageNet上和ViT以及DeiT的结果对比如表2所示,可以看到在大约相同的参数量下,T2T-ViT的精度超过了ViT和DeiT。
和ResNet系列的对比如表3,T2T-ViT的精度也更高。
代码解析
这里的代码是MMPretrain中的实现。其中soft split通过nn.Unfold实现,具体介绍见torch.nn.functional.unfold 用法解读-CSDN博客。一共包含3次soft split和2次re-structurization,输入大小为(1, 3, 224, 224),经过第一个soft split和re-structurization后得到输出(1, 64, 56, 56),经过第二次soft split和re-structurization后得到输出(1, 64, 28, 28),经过第三次soft split后得到输出(1, 196, 576),其中196=14x14,576=3x3x64,注意这次没有re-structurization了,然后就进入了self.project。
class T2TModule(BaseModule):
"""Tokens-to-Token module.
"Tokens-to-Token module" (T2T Module) can model the local structure
information of images and reduce the length of tokens progressively.
Args:
img_size (int): Input image size
in_channels (int): Number of input channels
embed_dims (int): Embedding dimension
token_dims (int): Tokens dimension in T2TModuleAttention.
use_performer (bool): If True, use Performer version self-attention to
adopt regular self-attention. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Notes:
Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
MACs
"""
def __init__(
self,
img_size=224,
in_channels=3,
embed_dims=384,
token_dims=64,
use_performer=False,
init_cfg=None,
):
super(T2TModule, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.soft_split0 = nn.Unfold(
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
if not use_performer:
self.attention1 = T2TTransformerLayer(
input_dims=in_channels * 7 * 7,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.attention2 = T2TTransformerLayer(
input_dims=token_dims * 3 * 3,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
else:
raise NotImplementedError("Performer hasn't been implemented.")
# there are 3 soft split, stride are 4,2,2 separately
out_side = img_size // (4 * 2 * 2)
self.init_out_size = [out_side, out_side]
self.num_patches = out_side**2
@staticmethod
def _get_unfold_size(unfold: nn.Unfold, input_size):
h, w = input_size # 224,224
kernel_size = to_2tuple(unfold.kernel_size) # (7,7)
stride = to_2tuple(unfold.stride) # (4,4)
padding = to_2tuple(unfold.padding) # (2,2)
dilation = to_2tuple(unfold.dilation) # (1,1)
h_out = (h + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
# (224 + 2x2 - 1x(7-1) - 1) // 4 + 1 = 56
w_out = (w + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
return (h_out, w_out)
def forward(self, x): # (1,3,224,224)
# step0: soft split
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:]) # (56,56)
x = self.soft_split0(x).transpose(1, 2) # (1,147,3136)->(1,3136,147), 3136=56x56
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2) # (1,3136,147)->(1,3136,64)->(1,64,3136)
B, C, _ = x.shape
x = x.reshape(B, C, hw_shape[0], hw_shape[1]) # (1,64,56,56)
# soft split
soft_split = getattr(self, f'soft_split{step}')
hw_shape = self._get_unfold_size(soft_split, hw_shape) # (28,28)
x = soft_split(x).transpose(1, 2) # (1,576,784)->(1,784,576), 576=3x3x64
# final tokens
x = self.project(x) # (1,196,576)->(1,196,384)
return x, hw_shape # _, (14,14)