Paper: "An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
Source code: https://github.com/google-research/vision_transformer
PyTorch code: pytorch_classification/vision_transformer
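1. Background
Transformer models have been enormously successful in NLP, but how best to apply self-attention to computer vision is still an open problem. Current approaches mostly combine self-attention with convolutional neural networks, or replace some convolutions with self-attention while keeping the overall architecture unchanged.
ViT shows that a Transformer kept as close as possible to its original structure can be applied directly to vision. When pre-trained on large datasets (such as ImageNet-21k or JFT-300M) and then transferred, it matches or exceeds state-of-the-art CNNs on ImageNet classification.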
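2. How ViT differs from CNNs
(1) Local receptive field: a convolution only combines nearby positions, so interacting with distant information requires a deeper model. In a Transformer, a single self-attention layer lets every position interact with all other positions;
(2) CNNs rely on pooling (or strided convolutions) to progressively reduce the spatial resolution of the feature maps, whereas ViT keeps the same number of tokens throughout and needs no pooling;
(3) Transformers typically have more parameters and need substantial compute (and, in ViT's case, large pre-training datasets);
(4) The Transformer's biggest advantage is that it unifies CV and NLP, so a single architecture can handle multimodal tasks. (Pure-vision applications still widely use CNNs because they are fast and need fewer compute resources.)
(5) ViT lacks the inductive biases (a form of prior knowledge) built into CNNs:
- Locality: a CNN slides its kernel over the image, assuming that neighboring regions have related features;
- Translation equivariance: the same kernel is applied at every position, so shifting the input shifts the output by the same amount, i.e. conv(shift(x)) = shift(conv(x)).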
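A quick numerical check of translation equivariance, written as a minimal sketch. It uses a 1-D convolution (strictly a cross-correlation, as in CNN layers) with circular padding. Circular padding is an assumption made so that the borders do not break the equality; with zero padding the property only holds away from the borders. `circular_conv1d` is a hypothetical helper, not part of the ViT code.

```python
import numpy as np

def circular_conv1d(x, kernel):
    """Cross-correlation with wrap-around padding: out[i] = sum_j kernel[j] * x[(i + j) % n]."""
    # np.roll(x, -j)[i] == x[(i + j) % n], so summing the shifted copies gives the sliding window
    return sum(w * np.roll(x, -j) for j, w in enumerate(kernel))

rng = np.random.default_rng(0)
x = rng.standard_normal(16)
kernel = np.array([1.0, -2.0, 0.5])
shift = 5

conv_then_shift = np.roll(circular_conv1d(x, kernel), shift)
shift_then_conv = circular_conv1d(np.roll(x, shift), kernel)
equivariant = np.allclose(conv_then_shift, shift_then_conv)
print(equivariant)  # True
```

ViT's patch embedding is convolutional and therefore shares this property, but only at the patch level. After that point, any spatial structure must be learned from data via the position embeddings.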
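3. Network architecture
The image is split into patches, which are fed in as a sequence. Each patch passes through a learnable linear projection and becomes a feature vector (a token); this is the patch embedding. A position embedding is then added, and the result goes through the Transformer Encoder. In addition, a learnable class embedding (cls) is prepended to the sequence. Every token interacts with every other token, so the cls token can gather useful information from all the patch embeddings. As a result, only its output is needed for the final decision, for example by feeding it to a classification head. The model is trained with a cross-entropy loss.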
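To make the pipeline above concrete, the following sketch traces the tensor shapes at each stage for the default ViT-B/16 settings (224×224 input, 16×16 patches, 768-dim tokens, 1000 classes). `vit_shapes` is a hypothetical helper that only does shape arithmetic; it contains no learned weights.

```python
def vit_shapes(batch=1, img_size=224, patch_size=16, in_c=3, embed_dim=768, num_classes=1000):
    """Hypothetical helper: trace tensor shapes through the ViT pipeline (no learned weights)."""
    assert img_size % patch_size == 0, "image size must be divisible by the patch size"
    num_patches = (img_size // patch_size) ** 2     # 14 * 14 = 196 for 224 / 16
    seq_len = num_patches + 1                       # + 1 for the cls token
    return {
        "image": (batch, in_c, img_size, img_size),
        "flattened_patch": patch_size * patch_size * in_c,   # values per patch: 768
        "patch_embeddings": (batch, num_patches, embed_dim),
        "with_cls_and_pos": (batch, seq_len, embed_dim),     # adding pos embedding keeps the shape
        "encoder_output": (batch, seq_len, embed_dim),
        "cls_output": (batch, embed_dim),
        "logits": (batch, num_classes),
    }

for name, shape in vit_shapes().items():
    print(f"{name:>18}: {shape}")
```

Changing `patch_size` to 32 shortens the sequence to 7 × 7 = 49 patches. This is why larger patches make the model cheaper: attention cost grows with the square of the sequence length.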
(1) Patch Embedding (Linear Projection of Flattened Patches)
This layer maps each patch to a vector, turning the image into a sequence of vectors.
- Input image size: 224×224×3
- Each patch: 16×16×3 = 768 values, flattened into one token
- Number of patches: (224/16) × (224/16) = 14 × 14 = 196
- Patch input: X = 196×768
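The Linear Projection is a fully connected layer with weight E = 768×768 that encodes each patch. It produces the 196×768 output X·E as the patch embedding, i.e. 196 tokens of length 768. In the code it is implemented as a convolution: a Conv2d with kernel_size = stride = 16 sees each patch exactly once, which is equivalent to flattening each patch and multiplying by E.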
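A minimal numpy sketch of the "flatten patches, then multiply by E" view described above. `patchify` is a hypothetical helper. `E` is random here and stands in for the learned projection. The helper flattens each patch in (row, column, channel) order. PyTorch's Conv2d weight uses (channel, row, column) order instead, but that only permutes the rows of E, so the two formulations are equivalent.

```python
import numpy as np

def patchify(img, p):
    """Split an (H, W, C) image into non-overlapping p x p patches, ordered row by row.

    Returns an array of shape (H//p * W//p, p*p*C); each row is one flattened patch.
    """
    H, W, C = img.shape
    gh, gw = H // p, W // p
    x = img[: gh * p, : gw * p].reshape(gh, p, gw, p, C)  # split both spatial axes into (grid, within-patch)
    x = x.transpose(0, 2, 1, 3, 4)                        # (gh, gw, p, p, C): group pixels by patch
    return x.reshape(gh * gw, p * p * C)

rng = np.random.default_rng(0)
img = rng.standard_normal((224, 224, 3))
X = patchify(img, 16)                        # (196, 768) flattened patches
E = rng.standard_normal((768, 768)) * 0.02   # stand-in for the learned 768x768 projection
tokens = X @ E                               # (196, 768) patch embeddings
print(X.shape, tokens.shape)                 # (196, 768) (196, 768)
```

Patch 0 is the top-left 16×16 block and patch 1 is its right neighbor. Patch 14 starts the second row of the 14×14 grid, matching the row-major order produced by `flatten(2)` in the PyTorch code.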
Code walkthrough:
import torch.nn as nn


class PatchEmbed(nn.Module):
    """2D image -> sequence of patch embeddings."""
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)  # (224, 224)
        patch_size = (patch_size, patch_size)  # (16, 16)
        self.img_size = img_size  # (224, 224)
        self.patch_size = patch_size  # (16, 16)
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # (14, 14)
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 196
        # kernel_size == stride == patch_size: each output position sees exactly one patch
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape  # (1, 3, 224, 224)
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)  # [B, C, H, W] -> (1, 768, 14, 14)
        x = x.flatten(2)  # flatten: [B, C, H, W] -> [B, C, HW], (1, 768, 196)
        x = x.transpose(1, 2)  # transpose: [B, C, HW] -> [B, HW, C], (1, 196, 768)
        x = self.norm(x)  # (1, 196, 768)
        return x
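(2) Class token
The class token is borrowed directly from NLP. In vision models, the final feature map is usually reduced by global average pooling and then passed to a linear classifier. To stay as close as possible to the original Transformer, ViT instead adopts BERT's class token. It learns useful features from the other tokens and serves as the representation of the whole image.
The cls token is a vector of size (1, 768) that is concatenated in front of the other tokens: cat[(1, 768), (196, 768)] = (197, 768).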
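The concatenation above, sketched in numpy for a batch of two images. A single cls vector is shared by every image in the batch. `np.broadcast_to` plays the role of PyTorch's `expand`: it creates a read-only view without copying. The normal initialization is a stand-in for `trunc_normal_`, since numpy has no truncated normal.

```python
import numpy as np

B, N, D = 2, 196, 768
patch_tokens = np.zeros((B, N, D))                                     # output of the patch embedding
cls_token = np.random.default_rng(0).normal(0.0, 0.02, size=(1, 1, D))  # one learnable vector

cls = np.broadcast_to(cls_token, (B, 1, D))        # like cls_token.expand(B, -1, -1): a view, no copy
x = np.concatenate([cls, patch_tokens], axis=1)    # (B, 197, 768): cls sits at index 0
print(x.shape)                                     # (2, 197, 768)
```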
Code walkthrough:
# VisionTransformer.__init__ (excerpt)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # (1, 1, 768), created as zeros
nn.init.trunc_normal_(self.cls_token, std=0.02)  # then re-initialized from a truncated normal with std 0.02, (1, 1, 768)
# VisionTransformer.forward_features (excerpt)
def forward_features(self, x):
    # [B, C, H, W] -> [B, num_patches, embed_dim]
    x = self.patch_embed(x)  # (1, 196, 768)
    # [1, 1, 768] -> [B, 1, 768]; expand shares memory instead of copying
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # (1, 1, 768) when B = 1
    x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768], (1, 197, 768)
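(3) Position Embedding
Position embeddings preserve the spatial positions of the image patches. A Transformer has none of the CNN's built-in priors, so position embeddings must encode where each patch token came from. The underlying reason is the permutation invariance of self-attention; strictly speaking, it is permutation-equivariant. Shuffling the input tokens only shuffles the outputs in the same way, so without position information the order of the patches is lost.
Concretely, ViT keeps a table whose rows are indexed by sequence position 1, 2, 3, …. Each row is a learnable vector of length 768 (197×768 in total), and the table is added element-wise to the tokens. ViT uses standard learnable 1D position embeddings. The options considered are:
- No position embedding: the model is insensitive to patch position;
- 1D position embedding: treat the 2-D grid of patches as a 1-D sequence;
- 2D position embedding: encode each patch's 2-D coordinates (x, y);
- Relative position embedding: encode the relative positions between patches.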
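The following sketch illustrates why the position embedding is needed. It uses single-head self-attention in numpy with random weights; `self_attention` is a hypothetical helper, not the repo's `Attention` class. Without position embeddings, permuting the input tokens merely permutes the outputs. Once a per-slot position embedding is added, that symmetry is broken.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over tokens x of shape (N, D)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
N, D = 6, 8
x = rng.standard_normal((N, D))
Wq, Wk, Wv = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(3))
perm = np.array([3, 0, 5, 1, 4, 2])           # a fixed, non-identity shuffle of the 6 tokens

# No position embedding: shuffled input -> identically shuffled output.
no_pe_equivariant = np.allclose(self_attention(x, Wq, Wk, Wv)[perm],
                                self_attention(x[perm], Wq, Wk, Wv))

# With a position embedding added per slot, the shuffled sequence is processed differently.
pos = rng.standard_normal((N, D))
pe_equivariant = np.allclose(self_attention(x + pos, Wq, Wk, Wv)[perm],
                             self_attention(x[perm] + pos, Wq, Wk, Wv))
print(no_pe_equivariant, pe_equivariant)      # True False
```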
Code walkthrough:
# VisionTransformer.__init__ (excerpt)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))  # (1, 197, 768), created as zeros
nn.init.trunc_normal_(self.pos_embed, std=0.02)  # then re-initialized from a truncated normal with std 0.02, (1, 197, 768)
self.pos_drop = nn.Dropout(p=drop_ratio)
# VisionTransformer.forward_features (excerpt)
def forward_features(self, x):
    ...
    # x: embedded patches + cls token, (1, 197, 768)
    x = self.pos_drop(x + self.pos_embed)  # element-wise sum, then dropout: (1, 197, 768)
(4) Transformer Encoder
The Transformer Encoder consists of L Transformer blocks. Each block contains two operations, Multi-Head Attention and an MLP. Each operation is preceded by LayerNorm (pre-norm) and followed by a residual connection.
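The pre-norm block structure described above can be written compactly, following the notation of the ViT paper. The symbols are: N patches of size P×P with C channels, hidden size D, h heads of size D_h = D/h, and L blocks. With this post's defaults, P = 16, C = 3, N = 196, D = 768, h = 12 and D_h = 64. Dropout and DropPath are omitted. In z_0, [·;·] denotes concatenation along the sequence axis; for the heads, it denotes concatenation along the feature axis.

```latex
\begin{aligned}
z_0 &= [\,x_{\text{class}};\ x_p^1 E;\ x_p^2 E;\ \dots;\ x_p^N E\,] + E_{\text{pos}},
  && E \in \mathbb{R}^{(P^2 C)\times D},\ E_{\text{pos}} \in \mathbb{R}^{(N+1)\times D} \\
z'_\ell &= \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, && \ell = 1,\dots,L \\
z_\ell &= \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, && \ell = 1,\dots,L \\
y &= \mathrm{LN}(z_L^0) \\[4pt]
\mathrm{head}_i &= \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{D_h}}\right) V_i,
  && Q_i = z W_i^Q,\ K_i = z W_i^K,\ V_i = z W_i^V \\
\mathrm{MSA}(z) &= [\,\mathrm{head}_1;\ \dots;\ \mathrm{head}_h\,]\, W^O
\end{aligned}
```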
- Input features: 197 × 768
- Multi-Head Attention: with h heads, each head's Q, K and V are 197 × (768 / h); concatenating the heads gives 197 × 768 again
- MLP: 197×768 → 197×3072 → 197×768
- Output after L blocks: 197 × 768
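The shapes above also determine the parameter count. The following is a sketch that counts the learnable parameters of one block and of the whole ViT-B/16 model. It assumes the `VisionTransformer` defaults shown in this post: qkv_bias=True, no pre-logits layer (representation_size=None), no distillation token, and 1000 classes. The helpers `block_params` and `vit_params` are hypothetical.

```python
def block_params(D=768, mlp_ratio=4.0, qkv_bias=True):
    """Learnable parameters in one pre-norm Transformer block (LayerNorm has weight + bias)."""
    H = int(D * mlp_ratio)                          # MLP hidden size: 3072 for ViT-B
    layer_norms = 2 * (2 * D)                       # norm1 and norm2, gamma + beta each
    qkv = D * 3 * D + (3 * D if qkv_bias else 0)    # fused q, k, v projection
    proj = D * D + D                                # output projection W^O
    fc1 = D * H + H
    fc2 = H * D + D
    return layer_norms + qkv + proj + fc1 + fc2

def vit_params(img_size=224, patch_size=16, in_c=3, D=768, depth=12, num_classes=1000):
    N = (img_size // patch_size) ** 2
    patch_embed = in_c * patch_size * patch_size * D + D   # Conv2d weight + bias
    cls_token = D
    pos_embed = (N + 1) * D
    final_norm = 2 * D
    head = D * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * block_params(D) + final_norm + head

print(block_params())  # 7087872
print(vit_params())    # 86567656
```

Roughly 85M of the ~86.6M parameters sit in the 12 blocks, consistent with the 86M the paper reports for ViT-Base.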
Code walkthrough:
1. Transformer Encoder initialization
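class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)  # default: LayerNorm
        act_layer = act_layer or nn.GELU
        # ... (embeddings omitted; see the full class in the MLP Head section)
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        # depth (= L) Transformer blocks applied in sequence
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)  # final LayerNorm after the last block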
2. Transformer block
class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        # sub-layer 1: x = x + DropPath(Attention(LayerNorm(x)))
        shortcut = x  # (1, 197, 768)
        x = self.norm1(x)  # (1, 197, 768)
        x = self.attn(x)  # (1, 197, 768)
        x = self.drop_path(x)  # (1, 197, 768)
        x = x + shortcut  # residual connection, (1, 197, 768)
        # sub-layer 2: x = x + DropPath(MLP(LayerNorm(x)))
        shortcut = x  # (1, 197, 768)
        x = self.norm2(x)  # (1, 197, 768)
        x = self.mlp(x)  # (1, 197, 768)
        x = self.drop_path(x)  # (1, 197, 768)
        x = x + shortcut  # residual connection, (1, 197, 768)
        return x
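`DropPath` (stochastic depth) is used above but its code is not shown in this post. The following is a minimal numpy sketch of the usual technique. During training, the whole residual branch is zeroed for randomly chosen samples, and the surviving samples are rescaled by 1/keep_prob so that the expected value is unchanged. At inference it is the identity. The function `drop_path` here is an illustrative stand-in, not the repo's implementation. The `dpr` line mirrors the linear decay rule in `VisionTransformer.__init__`.

```python
import numpy as np

def drop_path(x, drop_prob, rng, training=True):
    """Stochastic depth: drop the residual branch per sample, rescale the kept samples."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining axes (tokens, channels)
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = (rng.random(mask_shape) < keep_prob).astype(x.dtype)
    return x / keep_prob * mask

# Drop rate grows linearly with depth: 0 for the first block, drop_path_ratio for the last.
depth, drop_path_ratio = 12, 0.1
dpr = np.linspace(0, drop_path_ratio, depth).tolist()

rng = np.random.default_rng(0)
branch = np.ones((8, 197, 768))          # output of a residual branch for 8 samples
out = drop_path(branch, 0.5, rng)        # each sample is either all 0.0 or all 2.0
```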
3. Multi-Head Attention
class Attention(nn.Module):
    def __init__(self,
                 dim,  # 768
                 num_heads=8,  # default 8; VisionTransformer instantiates it with 12 heads
                 qkv_bias=False,
                 qk_scale=None,  # overrides the scale; if None, head_dim ** -0.5 = 1/sqrt(64) = 0.125
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # per-head dimension: dim is split evenly across the heads (768 // 12 = 64)
        head_dim = dim // num_heads
        # scaling factor: q @ k^T is divided by sqrt(d_k)
        self.scale = qk_scale or head_dim ** -0.5
        # a single fully connected layer produces q, k and v together
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        # after the heads are concatenated, W^O projects them; like q, k, v, it is a fully connected layer
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape  # (1, 197, 768)
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x)  # (1, 197, 768*3)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)  # (1, 197, 3, 12, 64)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, 1, 12, 197, 64)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (1, 12, 197, 64)
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = q @ k.transpose(-2, -1)  # transpose k, then matrix multiply: (1, 12, 197, 197)
        attn = attn * self.scale  # divide by sqrt(d_k): (1, 12, 197, 197)
        attn = attn.softmax(dim=-1)  # normalize over the keys: (1, 12, 197, 197)
        attn = self.attn_drop(attn)  # (1, 12, 197, 197)
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = attn @ v  # (1, 12, 197, 64)
        x = x.transpose(1, 2)  # (1, 197, 12, 64)
        x = x.reshape(B, N, C)  # concatenate the heads: (1, 197, 768)
        x = self.proj(x)  # (1, 197, 768)
        x = self.proj_drop(x)  # (1, 197, 768)
        return x
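The reshape/permute trick in `forward` is the least obvious part of the class. The following numpy sketch checks that it splits the heads correctly by comparing it with an explicit per-head loop. Assumptions: no biases and no dropout. The weights are laid out so that the layer computes `x @ W` (PyTorch's `nn.Linear` stores the transpose and computes `x @ W.T`). `mha_vectorized` and `mha_loop` are hypothetical names.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_vectorized(x, W_qkv, W_proj, num_heads):
    """Mirror of Attention.forward: fused qkv projection, heads split via reshape/transpose."""
    B, N, C = x.shape
    hd = C // num_heads
    qkv = (x @ W_qkv).reshape(B, N, 3, num_heads, hd).transpose(2, 0, 3, 1, 4)  # (3, B, h, N, hd)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = softmax(q @ k.transpose(0, 1, 3, 2) * hd ** -0.5)                     # (B, h, N, N)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)                       # concatenate heads
    return out @ W_proj

def mha_loop(x, W_qkv, W_proj, num_heads):
    """Same computation written head by head, to check the reshape/transpose logic."""
    B, N, C = x.shape
    hd = C // num_heads
    qkv = x @ W_qkv                                   # (B, N, 3C), laid out as [q | k | v]
    heads = []
    for h in range(num_heads):
        cols = slice(h * hd, (h + 1) * hd)            # columns belonging to head h
        q = qkv[..., 0 * C:1 * C][..., cols]
        k = qkv[..., 1 * C:2 * C][..., cols]
        v = qkv[..., 2 * C:3 * C][..., cols]
        a = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(hd))
        heads.append(a @ v)
    return np.concatenate(heads, axis=-1) @ W_proj

rng = np.random.default_rng(0)
B, N, C, H = 1, 197, 768, 12
x = rng.standard_normal((B, N, C))
W_qkv = rng.standard_normal((C, 3 * C)) * 0.02
W_proj = rng.standard_normal((C, C)) * 0.02
y_vec = mha_vectorized(x, W_qkv, W_proj, H)
y_loop = mha_loop(x, W_qkv, W_proj, H)
print(y_vec.shape, np.allclose(y_vec, y_loop))        # (1, 197, 768) True
```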
4. MLP
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features  # Block passes int(768 * 4.0) = 3072
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):  # (1, 197, 768)
        x = self.fc1(x)  # expand: (1, 197, 3072)
        x = self.act(x)  # GELU: (1, 197, 3072)
        x = self.drop(x)  # (1, 197, 3072)
        x = self.fc2(x)  # project back: (1, 197, 768)
        x = self.drop(x)  # (1, 197, 768)
        return x
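A numpy sketch of the same MLP in evaluation mode, with dropout omitted. It uses the exact GELU, x·Φ(x), which is the default form of `torch.nn.GELU`. The random weights are stand-ins for learned ones, and `gelu`/`mlp` are hypothetical helper names.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf)

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def mlp(x, W1, b1, W2, b2):
    h = gelu(x @ W1 + b1)       # expand: (..., 197, 768) -> (..., 197, 3072)
    return h @ W2 + b2          # project back: (..., 197, 3072) -> (..., 197, 768)

rng = np.random.default_rng(0)
D, H = 768, 3072                # H = D * mlp_ratio with mlp_ratio = 4
W1, b1 = rng.standard_normal((D, H)) * 0.02, np.zeros(H)
W2, b2 = rng.standard_normal((H, D)) * 0.02, np.zeros(D)
x = rng.standard_normal((1, 197, D))
y = mlp(x, W1, b1, W2, b2)
print(y.shape)                  # (1, 197, 768)
```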
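(5) MLP Head
The prediction is read from position 0 of the output of the last (L-th) Transformer layer, i.e. the cls token (1, 768). Pooling over the tokens instead, for example global average pooling, also works. According to the original paper, the head is an MLP with one hidden layer when pre-training on ImageNet-21k; in the code below this is Linear + tanh + Linear. When transferring to ImageNet-1k or your own dataset, a single Linear layer is enough.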
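A numpy sketch of the head described above. It assumes `encoder_out` has already passed through the final LayerNorm. `pool="token"` takes the cls output at index 0, as `forward_features` does. `pool="gap"` averages the patch tokens instead; that variant is an alternative this code does not use. The optional `W_pre`/`b_pre` reproduce the Linear + tanh pre-logits layer of the `representation_size` branch. `classify` and its arguments are hypothetical.

```python
import numpy as np

def classify(encoder_out, W_head, b_head, W_pre=None, b_pre=None, pool="token"):
    """Map encoder outputs (B, 1 + N, D) to class logits (B, num_classes)."""
    if pool == "token":
        feat = encoder_out[:, 0]                    # cls token output: (B, D)
    else:
        feat = encoder_out[:, 1:].mean(axis=1)      # average of the patch tokens: (B, D)
    if W_pre is not None:
        feat = np.tanh(feat @ W_pre + b_pre)        # pre-logits: Linear + tanh
    return feat @ W_head + b_head                   # final Linear classifier

rng = np.random.default_rng(0)
B, D, num_classes = 2, 768, 1000
z = rng.standard_normal((B, 197, D))
W_head, b_head = rng.standard_normal((D, num_classes)) * 0.02, np.zeros(num_classes)

logits = classify(z, W_head, b_head)                                   # fine-tuning style head
W_pre, b_pre = rng.standard_normal((D, D)) * 0.02, np.zeros(D)
logits_pre = classify(z, W_head, b_head, W_pre, b_pre)                 # pre-training style head
logits_gap = classify(z, W_head, b_head, pool="gap")
print(logits.shape, logits_pre.shape, logits_gap.shape)                # (2, 1000) (2, 1000) (2, 1000)
```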
Code walkthrough:
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn

# PatchEmbed and Block are shown above; DropPath and _init_vit_weights (not shown here)
# must also be defined in the same module.


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # 768
        self.num_tokens = 2 if distilled else 1  # 1 (cls only); 2 when a distillation token is added
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # (1, 1, 768), created as zeros
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))  # (1, 197, 768), created as zeros
        self.pos_drop = nn.Dropout(p=drop_ratio)
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer (pre-logits: Linear + tanh), only when representation_size is set
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # truncated normal with std 0.02, (1, 197, 768)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)  # truncated normal with std 0.02, (1, 1, 768)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # (1, 196, 768)
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # (1, 1, 768) when B = 1
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768], (1, 197, 768)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)  # (1, 197, 768)
        x = self.blocks(x)  # (1, 197, 768)
        x = self.norm(x)  # (1, 197, 768)
        # take the output at the cls token position, i.e. the class feature vector
        if self.dist_token is None:
            # all batch entries, index 0 along the sequence dimension (the cls token)
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):  # (1, 3, 224, 224)
        x = self.forward_features(x)  # (1, 768)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions separately
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)  # (1, 1000)
        return x