上次在中原始VIT(Vision Transformer)总结(原理与代码)_vision transformer代码 github-CSDN博客介绍了一下原版VIT的原理和代码,这次介绍一下基于原始ViT改进的一个端侧的模型,论文针对原始VIT存在的一些缺点进行相应的改进。这篇论文是Apple出的,论文整体的质量很不错,一些观点还是很值得借鉴的,笔者也是很推荐大家能够读一读原文。
原文:https://arxiv.org/pdf/2110.02178
要知道怎么改进一个事物,先要了解原有的事物的优缺点,首先我们来看一下原始Vision Transformer存在的优缺点。
优点:强大的全局表征能力,长依赖关系的学习,在数据量够大的时候,效果能做到比现有的CNN更好。
缺点:1. 缺少空间上的归纳偏置(inductive bias, 归纳偏置更有利于学习到局部空间表征); 而CNN却有空间上的归纳,有不错的局部空间表征能力。2. 以self-attention为基础的vit,计算量大,推理延迟高,self-attention的算法时间复杂度达到了,这里d是embedding的维度,n是feature map的面积(h * w)。可见在稍微大一点的分辨率上,用vit的使用是非常吃计算资源的。
针对原始VIT存在的这些缺点,MobileVIT这篇文章,发出了拷问(见原文,如黄色高亮部分如下):
能否把CNN和ViTs的长处结合起来,构建一个轻量级低延迟的用于手机端侧的视觉任务的网络?围绕着这个问题,论文作者们提出了MobileVIT。并且这个网络,在不同的视觉任务上,都明显好与纯CNN以及纯ViT的网络效果。
1. 网络结构介绍
如下Figure 1(b)就是MobileViT的整体结构。整体结构看起来非常的简单,黄色块是卷积,红色块是MobileNetV2中的MVConv模块,绿色的部分就论文的重点MobileViT block, 在下面会详细介绍。整个结构就是一个普通的layer by layer, block by block的结构,最后是一个分类的head(global pool + linear)。接下来重点介绍Mobile block结构。
1.1 MobileViT block原理介绍:
MobileViT模块就是论文中为了使用更少的参数,做到更好的局部和全局特征提取能力,而提出来的一种结构,也是论文主要的核心部分。作者们想要用一种高效的方式获取长范围非局部的感受野(long-range non-local)。而现有的方式中,膨胀卷积(dilated convolutions)是一种常见的方式。但是膨胀卷积需要特别注意设置这个膨胀系数(dilation rates), 否则权重很容易作用在pdding的零上。然后另一种实现长范围非局部的感受野的方法就是VIT中的self-attention了,而vit的缺陷,在开始已经介绍了,计算量大并且缺乏空间归纳偏置。
那么为了让MobileViT可以在有空间归纳偏置的情况下也能学习全局表征的能力,作者们是这样做到:
首先输入到MobileViT block模块的特征图会经过卷集模块,来获取局部的特征,得到特征图特征图(feature Map)。对于, 将其分成N个不相互重叠的拉平的小块(patch), 其中 w h分别是小块的宽高,为小块的数量。而patch内部的关系,用transformer来进行编码:
和ViT会失去像素之间的空间顺序不同的是,MobileViT既不会失去patch顺序,也不会失去每个patch中像素之间的空间顺序。经过transformer之后,再将transformer的输出又重新变回到(B, C, H, W)维度的feature map 。然后又通过一个1x1的卷积,并且和输入MobileViTblock的输入形成残差合并(concatention)。之后再使用一个卷积操作将这个合并的特征进行融合。可以看到通过卷积编码了局部的信息,而通过transformer编码了全局信息。这样就实现了全局和局部信息的特征提取能力。
整个网络, 论文提供了三个参数配置,如下所示:
1.2 MobileViT block代码介绍:
我们首先看一下Figure1画的MobileViT block的图。这个图配合上面的原理介绍,应该大致就知道怎么来实现这个代码了。从左到右可分为5个部分:
-
首先这个local representations, 里面就是一个3x3的卷积 + 1x1的卷积,这个很好实现;
-
Unfold部分:就是一个feature map的reshape过程,需要将的feature map变形为;
-
Transformer encoder部分,这个和之前原始ViT的实现一模一样,这个可以参考上一篇文章的代码实现;
-
Fold部分:又把这个形状变回到(B, C, H, W);
-
一个1x1conv + 和输入的concat + 用来作特征融合的3x3conv。
根据以上的拆分,笔者也用pytorch实现了一下MobileViT block的代码,如下仅供大家参考:
class MobileViTBlock(nn.Module):
def __init__(self,
input_dim,
transformer_dim,
head_dim=32,
ffn_dim=128,
attn_dropout=0.0,
ffn_dropout=0.0,
num_transformer=2,
patch_size=(8, 8),
no_fusion=False
):
super().__init__()
self.local_rep = nn.Sequential(
ConvLayer2d(in_channels=input_dim,
out_channels=input_dim,
kernel_size=3),
ConvLayer2d(in_channels=input_dim,
out_channels=transformer_dim,
kernel_size=1,
use_act=False, use_norm=False)
)
num_heads = transformer_dim // head_dim
self.global_rep = nn.Sequential(*[
TransformerEncoder(
transformer_dim, transformer_dim,
num_head=num_heads, head_dim=head_dim, hidden_dim=ffn_dim,
attn_drop_rate=attn_dropout, drop_rate=ffn_dropout
) for _ in range(num_transformer)
])
self.conv_proj = ConvLayer2d(in_channels=transformer_dim,
out_channels=input_dim,
kernel_size=1)
self.patch_size = patch_size
self.fusion = None
if not no_fusion:
self.fusion = ConvLayer2d(in_channels=2 * input_dim, out_channels=input_dim, kernel_size=3)
def forward(self, x):
res = x
fm = self.local_rep(x)
patches, info_dict = unfolding(fm, self.patch_size)
patches = self.global_rep(patches)
fm = folding(patches, info_dict)
fm = self.conv_proj(fm)
if self.fusion is not None:
fm = self.fusion(torch.cat((res, fm), dim=1))
return fm
基本上是按照上面的拆解写的。其中的unfolding和folding函数以及整体的整个网络,读者可以尝试自己去实现一下,这里由于篇幅有限,就不过多介绍了。如果需要参考的话,大家可以参考官网,而笔者也在自己的github代码仓写了一个简单的版本,大家也可以自行参考。
2. 其他
其他的还有一个多尺度训练样本的问题,这个也算是一个训练方法的改进。并且对于分布式的训练,会有根据输入样本的尺寸调整每张卡上的batch_size的方法,这样做有利于提高训练效率。以及一些实验细节,这里笔者就不再赘述了,感兴趣的同学直接阅读原论文哈。
最后,希望大家可以一起学习,一起探讨,一起进步,也希望我的文章能够给大家在学习上带来一点参考价值。觉得写的不错的同学可以点赞收藏加个关注,谢谢各位同学,咱们下一篇文章见!