一、模型简介
1. 论文地位:VIT模型(Vision Transformer),这是一篇Google于2021年发表在计算机视觉顶级会议ICLR上的一篇文章。它首次将Transformer这种发源于NLP领域的模型引入到了CV领域,并在ImageNet数据集上击败了当时最先进的CNN网络。这是一个标志性的网络,代表transformer击败了CNN和RNN,同时在CV领域和NLP领域达到了统治地位,此后基本在ImageNet排行榜上都是基于transformer架构的模型了。
2. 论文下载链接:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE ;
3. 推荐复现的代码仓库,可以star我这个GitHub开源项目,对每行代码有详尽的注释:VIT模型详解
二、模型亮点及整体架构介绍
1. 整体架构:VIT首次将transformer模型运用到CV领域并且取得了不错的分类效果,模型原理图如图1所示。该图表示了VIT模型的整体架构,可以看出VIT只用了transformer模型的编码器部分,并未涉及解码器。其实VIT模型不难理解,只需要将其拆成三个部分(1.图像特征嵌入模块;2.transformer编码器模块;3.MLP分类模块)就可以很容易捋顺它的结构。
1.1.图像特征嵌入模块:标准的VIT模型对图像的输入尺寸有要求,必须为224*224.图像输入之后首先是需要进行patch分块,一般设置patch的尺寸为16*16,那么一共能生成(224/16)*(224/16)=196个patch块。这部分内容在代码中如何实现呢?其实很简单,就是用一个卷积层就可以实现,其卷积核大小为patch size=16, 步长为patch size=16.
具体的代码如下所示,每行代码均有详细注释,展示了图像分块和特征嵌入的完整过程,嵌入之后的特征维度是[196,768],之后我们还需要加上位置编码和类别token,前者使用直接相加的方法,后者使用concat的方法,所以加上类别token后,特征的维度变化为[197,768]:
class PatchEmbed(nn.Module): # 继承nn.Module
"""
所有注释均采用VIT-base进行说明
图像嵌入模块类
2D Image to Patch Embedding
"""
# 初始化函数,设置默认参数
def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
super().__init__() # 继承父类的初始化方法
# 输入图像的size为224*224
img_size = (img_size, img_size)
# patch_size为16*16
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
# 滑动窗口的大小为14*14, 224/16=14
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
# 图像切分的patch的总数为14*14=196
self.num_patches = self.grid_size[0] * self.grid_size[1]
# 使用一个卷积层来实现图像嵌入,输入维度为[BatchSize,3,224,224],输出维度为[BatchSize,768,14,14],
# 计算公式 size= (224-16+2*0)/16 + 1= 14
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
# 如果norm_layer为True,则使用,否则忽略
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# 获取BatchSize,Channels,Height,Width
B, C, H, W = x.shape # 输入批量的图像
# VIT模型对图像的输入尺寸有严格要求,故需要检查是否为224*224,如果不是则报错提示
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
# 将卷积嵌入后的图像特征图[BatchSize, 768, 14,14],从第3维度开始展平,得到BatchSize*768*196;
# 然后转置1,2两个维度,得到BatchSize*196*768
x = self.proj(x).flatten(2).transpose(1, 2)
# norm_layer层规范化
x = self.norm(x)
return x
1.2. Transformer Encoder:主要由LayerNorm层,多头注意力机制,MLP模块,残差连接这5个知识点模块构成组成。这是整个VIT模型最重要也是最需要花精力理解的地方,要搞清楚transformer的编码器部分,首先需要搞清楚多头注意力机制。
由于注意力机制这块内容较多,引用博文供读者学习:狗都能看懂的self attention讲解 ;
详解self attention以及 multi head self attention的原理 。后续将补上对于注意力机制这部分内容自己的理解,其实理解了注意力机制也就理解了transformer架构的基本原理。两种注意力机制的原理如下图所示。
接下来我们按顺序往前讲解,LayerNorm层比较简单,直接调用pytorch中的nn.LayerNorm就可以实现。接下来看多头注意力机制的实现,如下所示。首先通过使用一个全连接层生成q,k,v的初始值,然后使用reshape和维度调换来进行调整,最后使用切片操作分别获得单独的Q,K,V,接下来就是transformer原始文章里提出的注意力机制的公式的实现了,公式如下:
class Attention(nn.Module):
# 经过注意力层的特征的输入和输出维度相同
def __init__(self,
dim, # 输入token的维度
num_heads=8, # multiHead中 head的个数
qkv_bias=False, # 决定生成Q,K,V时是否使用偏置
qk_scale=None,
attn_drop_ratio=0.,
proj_drop_ratio=0.):
super(Attention, self).__init__()
# 设置多头注意力的注意力头的数目
self.num_heads = num_heads
# 针对每个head进行均分,它的Q,K,V对应的维度;
head_dim = dim // num_heads
# 放缩Q*(K的转置),就是根号head_dim(就是d(k))分之一,及和原论文保持一致
self.scale = qk_scale or head_dim ** -0.5
# 通过一个全连接层生成Q,K,V
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# 定义dropout层
self.attn_drop = nn.Dropout(attn_drop_ratio)
# 通过一个全连接层实现将每一个头得到的注意力进行拼接
self.proj = nn.Linear(dim, dim)
# 使用dropout层
self.proj_drop = nn.Dropout(proj_drop_ratio)
def forward(self, x):
# [batch_size, num_patches + 1, total_embed_dim],其中,num_patches+1中的加1是为了保留class_token的位置
B, N, C = x.shape # [batch_size, 197, 768]
# 生成Q,K,V;qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head],
# 调换维度位置,permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# 将Q,K,V分离出来,切片后的Q,K,V的形状[batch_size, num_heads, num_patches + 1, embed_dim_per_head]
# 这个的维度相同,均为[BatchSize, 8, 197, 768]
q, k, v = qkv[0], qkv[1], qkv[2]
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
# 即文章里的那个根据q,k, v 计算注意力的公式的Q*K的转置再除以根号dk
attn = (q @ k.transpose(-2, -1)) * self.scale
# 对得到的注意力结果的每一行进行softmax处理
attn = attn.softmax(dim=-1)
# 添加dropout层
attn = self.attn_drop(attn)
# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
# reshape: -> [batch_size, num_patches + 1, total_embed_dim]
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 乘V,通过reshape处理将每个head对应的结果进行拼接处理
# 使用全连接层进行映射,维度不变,输入[BatchSize, 197, 768], 输出也相同
x = self.proj(x)
# 添加dropout层
x = self.proj_drop(x)
return x
然后是MLP模块,也就是transformer原文中的前馈网络(feed forward),这一部分其实比较简单,没什么可讲的,就是两个全连接层加上dropout层实现:
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
# 注意当or两边存在None值时,输出是不为None那个
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
# MLP模块中的第一个全连接层,输入维度in_features, 输出维度hidden_features
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
# MLP模块中的第二个全连接层,输入维度hidden_features, 输出维度out_features
x = self.fc2(x)
x = self.drop(x)
return x
残差连接模块,这个比较简单,代码实现如下:
# 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
# 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
x = x + self.drop_path(self.attn(self.norm1(x)))
# 将输入x与经过layer norm和MLP处理后的值进行残差相加
x = x + self.drop_path(self.mlp(self.norm2(x)))
好了,接下来是将这些不同功能的模块进行包装,代码和注释如下:
class Block(nn.Module):
# 集成了Transformer Encoder的所有功能
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_ratio=0.,
attn_drop_ratio=0.,
drop_path_ratio=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super(Block, self).__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
# 定义全连接层,输入维度为embedding dim=768,隐藏层为embedding dim*4=3072,输出层为in_features=768
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
def forward(self, x):
# 相当于resnet中的shortcut,原论文中图1的结构中的+和箭头就是这个意思。因为输入和输出的维度完全相同,所以二者可以相加
# 将输入x与经过layer norm和多头注意力处理后的值进行残差相加
x = x + self.drop_path(self.attn(self.norm1(x)))
# 将输入x与经过layer norm和MLP处理后的值进行残差相加
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
其实讲到这里,基本已经理解了VIT模型知识体系的三分之二。接下来就是最后的MLP分类模块了,这一块比较简单,甚至可以只用一层全连接层来解决。之前没有了解过VIT的小伙伴,这里需要提示一下,我们输入到MLP类别分类器中的特征只有类别token。经过N层transformer编码器处理后的特征的维度与输入前相同,均为[197,768],我们只使用列表切片的方式提取出类别token,维度为[1,768].进行下一步的类别分类。有小伙伴可能不理解,那不是其它的特征没有用到吗?浪费了是不是。其实不是,多头注意力机制可以让不同位置的特征进行全面交互,这里输出的类别token和之前输入的类别token早已发生了巨变,这种变化是由其它特征影响的。
最后提供一下,transformer模型的整体架构代码:
class VisionTransformer(nn.Module):
# 集成VIT模型架构
def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_c (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_ratio (float): dropout rate
attn_drop_ratio (float): attention dropout rate
drop_path_ratio (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
"""
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
# 创建可学习的位置token,形状为 (1, num_patches + self.num_tokens, embed_dim),初始值全为0
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
# nn.Sequential作为容器用来包装层,通过列表循环构建了depth个block
self.blocks = nn.Sequential(*[
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Representation layer,即pre_logits层。这个if语句是问,如果需要pre_logits层并且不进行模型蒸馏则...
if representation_size and not distilled:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
("fc", nn.Linear(embed_dim, representation_size)),
("act", nn.Tanh())
]))
else:
self.has_logits = False
# 跳过pre_logits层
self.pre_logits = nn.Identity()
# Classifier head(s),使用一个全连接层进行分类结果输出
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
# Weight init,对位置编码进行初始化
nn.init.trunc_normal_(self.pos_embed, std=0.02)
if self.dist_token is not None:
nn.init.trunc_normal_(self.dist_token, std=0.02)
# 对类别编码进行初始化
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(_init_vit_weights)
def forward_features(self, x):
# [B, C, H, W] -> [B, num_patches, embed_dim],即[B, 196, 768]
x = self.patch_embed(x)
# 定义类别token,维度为[1,1,768]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
# 注意dist_token是用于模型蒸馏的,此处设置为None
if self.dist_token is None:
# concat上类别token
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
# 添加位置编码
x = self.pos_drop(x + self.pos_embed)
# 将特征输入transformer编码器
x = self.blocks(x)
# LayerNorm层
x = self.norm(x)
# 注意dist_token是用于模型蒸馏的,此处设置为None
if self.dist_token is None:
# 返回类别token
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
# 进行forward_features之后获得训练后的类别token
x = self.forward_features(x)
# head_dist等于None,直接执行else
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1])
if self.training and not torch.jit.is_scripting():
# during inference, return the average of both classifier predictions
return x, x_dist
else:
return (x + x_dist) / 2
else:
# 使用全连接层输出分类结果
x = self.head(x)
return x