paper:Rethinking Spatial Dimensions of Vision Transformers
official implementation:https://github.com/naver-ai/pit
third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/pit.py
出发点
这篇论文的出发点是探索如何在视觉Transformer(Vision Transformer,简称ViT)架构中有效地引入空间维度转换,从而改进其性能。ViT已经在计算机视觉任务中展示了其强大的能力,但在设计上并没有充分利用卷积神经网络(CNN)的一些成功经验,尤其是空间维度和通道维度的转换原则。
解决了什么问题
论文解决了如何在ViT中引入空间维度转换的问题。具体来说,ViT在不同层之间保持相同的空间维度,而这与CNN的设计原则不同。CNN通过逐渐减小空间维度并增加通道维度来提升模型的表达能力和泛化性能。论文通过引入一种基于池化层的Vision Transformer(Pooling-based Vision Transformer,简称PiT),验证了这种空间维度转换在ViT中的有效性。
创新点
- 引入池化层到ViT:提出了一种新的池化层设计,使得ViT能够像CNN一样进行空间维度的减小和通道维度的增加。
- 设计PiT架构:设计了PiT架构,通过引入池化层来实现ViT的空间维度转换。
- 实验验证:通过广泛的实验,验证了PiT在多个任务上的性能提升,包括图像分类、目标检测和鲁棒性评估。
效果
PiT在多个方面的性能优于原始的ViT:
- 模型能力:在相同的计算成本下,PiT的训练损失更低,表明其模型能力更强。
- 泛化性能:PiT在训练精度和验证精度上均表现出更好的泛化性能。
- 任务表现:在ImageNet分类任务上,PiT在不同规模和训练环境下均优于ViT;在COCO数据集上的目标检测任务中,PiT作为骨干网络的表现也优于ViT。此外,PiT在鲁棒性评估中也表现出色。
方法介绍
在CNN中,常用的架构设计是随着网络层的加深,分辨率逐渐减小,通道数逐渐增大。通常有两种方式实现分辨率的下采样,一种是池化层,另一种是步长为2的卷积。而在ViT中只在网络一开始通过patch embedding减少了分辨率,在之后所有的阶段中分辨率都保持不变。
作者首先探索了ResNet-style和ViT-style两种不同的维度转换方式。作者将ResNet-50中所有的下采样全部去掉,并在网络一开始用ViT中的patch embedding层将分辨率缩小到14x14,然后与原始的ResNet-50进行对比,结果如图2所示。
从图2(a)可以看出,原始的ResNet在相同的计算开销(FLOPs)下训练损失更低,表明ResNet-style增强了架构的capability。如图2(b)所示,ResNet-style的验证精度更高,表明ResNet-style的维度变化方式有助于模型的泛化。综上所示,ResNet-style的维度变化提高了模型的能力和泛化性能,从而显著提高了模型的精度,如图2(c)所示。
为了将ResNet-style的维度变化方式引入ViT,作者提出了一种新的架构Pooling-based Vision Transformer(PiT)。首先为ViT设计了一个池化层,如图4所示
首先将序列形式的特征reshape成2D特征图的形式,然后通过步长为2的深度卷积实现spatial reduction,然后再reshape为序列形式。而对于class token,当它和spatial token一起时维度无法精确的reshape回2D特征图,因此将class token单独拿出来,只有spatial token进行深度卷积,class token通过一个全连接层来对齐维度,最后再与spatial token拼接到一起。
实验结果
不同尺度的PiT的配置如表1所示
在ImageNet的结果如下表所示,可以看到在不同的scale下,PiT的效果都超过了ViT。
代码解析
这里是timm中的实现。其中每个stage由一个Transformer类构成,实现如下。可以看到当self.pool不为空时即需要进行下采样,将spatial token和class token分别送入池化层,得到输出后再concat到一起,然后再经过self.blocks,self.blocks中是若干层Attention+MLP,得到的输出再split开,进入下一个stage的Transformer中。
class Transformer(nn.Module):
def __init__(
self,
base_dim,
depth,
heads,
mlp_ratio,
pool=None,
proj_drop=.0,
attn_drop=.0,
drop_path_prob=None,
norm_layer=None,
):
super(Transformer, self).__init__()
embed_dim = base_dim * heads
self.pool = pool
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim,
num_heads=heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_prob[i],
norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
for i in range(depth)])
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
x, cls_tokens = x
token_length = cls_tokens.shape[1] # 1
if self.pool is not None:
# (1,144,27,27),(1,1,144)
x, cls_tokens = self.pool(x, cls_tokens)
# (1,288,14,14),(1,1,288)
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # (1,288,196)->(1,196,288)
x = torch.cat((cls_tokens, x), dim=1) # (1,197,288)
x = self.norm(x)
x = self.blocks(x) # (1,197,288)
cls_tokens = x[:, :token_length] # (1,1,288)
x = x[:, token_length:] # (1,196,288)
x = x.transpose(1, 2).reshape(B, C, H, W) # (1,288,196)->(1,288,14,14)
return x, cls_tokens
池化层的实现如下,可以看到spatial token经过3x3-s2的depthwise convolution,而class token经过一个fc,封闭得到对应的输出。
class Pooling(nn.Module):
def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
super(Pooling, self).__init__()
self.conv = nn.Conv2d(
in_feature, # 144
out_feature, # 288
kernel_size=stride + 1, # 3
padding=stride // 2, # 1
stride=stride, # 2
padding_mode=padding_mode,
groups=in_feature, # 144
)
self.fc = nn.Linear(in_feature, out_feature)
def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: # (1,144,27,27)
x = self.conv(x) # (1,288,14,14)
cls_token = self.fc(cls_token) # (1,1,288)
return x, cls_token
另外多提一嘴,以PiT-S为例,此时经过patch embedding后第一个stage的输入大小为27x27,这里patch embdding是通过conv实现的,代码如下。其中kernel_size=patch_size=16,stride=8,原始输入图片大小为224x224,注意这里padding=0,所以最后得到的输出大小为27x27,而不是28x28。
class ConvEmbedding(nn.Module):
def __init__(
self,
in_channels,
out_channels,
img_size: int = 224,
patch_size: int = 16,
stride: int = 8,
padding: int = 0,
):
super(ConvEmbedding, self).__init__()
padding = padding
self.img_size = to_2tuple(img_size)
self.patch_size = to_2tuple(patch_size)
self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
self.grid_size = (self.height, self.width)
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size=patch_size,
stride=stride, padding=padding, bias=True)
def forward(self, x):
x = self.conv(x) # (1,144,27,27), kernel_size=16,stride=8,padding=0, 卷积核滑动到最后一个位置时一半在特征图里一半在外面,所以结果是27而不是28
return x