在这里分享一下我对于ViT骨干网络的代码理解,ViT paper发表于2020年,掀起了transformer结构在视觉任务中的应用潮流。
paper:https://arxiv.org/abs/2010.11929
参考源码:https://github.com/bubbliiiing/classification-pytorch/tree/main/nets
2022.3.7进行了代码注释的更新。
# -------------------------------------------------------------------------------#
# ViT将transformer应用到图像图块(patch)之上来完成图像分类任务,将patch看作NLP中的token
# ViT将input每隔一定的区域大小划分图片块,之后将划分后的图片块组合成序列,这一步通过一层conv来实现
# 将组合后的序列传入transformer特有的Multi-head Self-attention(多头自注意力机制)
# 进行特征提取,最后利用Cls Token进行分类
# 这样的特性使得ViT可以作为通用的视觉任务Backbone进而完成下游任务(目标检测,语义分割...)
# 组成ViT的两个部分:1.特征提取部分;2.分类部分
# 在这里仅仅使用了TRM中的编码器架构而没有使用解码器架构,在BERT中其也只是使用了编码器架构
# -------------------------------------------------------------------------------#
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
#---------------------------------------#
# GELU激活函数,利用近似的数学公式得以实现
# 在transformer block的全连接层中加以运用
#---------------------------------------#
class GELU(nn.Module):
def __init__(self):
super(GELU, self).__init__()
def forward(self, x):
return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
#---------------------------------------#
# drop_path函数将在DropPath这个类中加以运用
#---------------------------------------#
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
#---------------------------------------#
# DropPath类
#---------------------------------------#
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
# --------------------------------------------------------#
# 图像Patch化+Embedding模块的实现
# 首先进行普通的二维卷积操作使其patch化
# 之后将这个特征层组合成序列,做法为将高宽维度进行平铺(flatten操作)
# 但并未嵌入位置信息
# --------------------------------------------------------#
class PatchEmbed(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_channels=3, num_features=768, norm_layer=None,flatten=True):
super(PatchEmbed,self).__init__()
self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
self.flatten = flatten
# ----------------------------------------#
# 投影即为一个卷积操作,其中padding=0,k=16,s=16
# 进而output的shape为Batch_size,768,14,14
# 在flatten并转置之后,shape为bs,196,768
# ----------------------------------------#
self.proj = nn.Conv2d(in_channels, num_features, kernel_size=patch_size, stride=patch_size)
#-----------------------------#
# 此处为只加上一层,但无实际作用
# 通过nn.Identity()加以实现
# 当norm_layer有定义时方可起作用
#-----------------------------#
self.norm = norm_layer(num_features) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(start_dim=2,end_dim=3).transpose(1, 2) # BCHW -> B(HW)C
x = self.norm(x)
return x
# --------------------------------------------------------------------------------------------------------------------#
# Self-Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value.其中query是查询向量、key是键向量、v是值向量.可以通过全连接操作来获得
# 首先利用查询向量query点乘转置后的键向量key,这一步可以通俗的理解为:利用查询向量去查询序列的特征,获得序列每个部分的重要程度score.
# 然后利用score点乘value,这一步可以通俗的理解为:将序列每个部分的重要程度重新施加到序列的值上去.
# 在这里,我们的注意力头的数量为12,且经过patch化,flatten以及添加cls token之后,此时张量的shape为bs,197,768
# 在forward函数中,按照transformer公式进行实现并穿插全连接层(Linear)以及Dropout层进行实现
# 在经过整个自注意力模块前后,特征图像的shape不发生改变
# --------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
def __init__(self, dim, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
super(Attention,self).__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5 #scale=0.125
#----------------------------------------#
# nn.Linear从输入输出张量的角度来讲,相当于为
# in_features->out_features
# 其所处理的为最后一个维度
# bias=False,则为将不会学习额外的bias
# 即为 768 -> 768*3
#----------------------------------------#
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#---------------------------------------------#
# nn.Dropout为一种正则化的手段,目的为减少神经元对部分
# 上层神经元的依赖进而降低过拟合的风险
# 并且其前后的shape不发生改变
# 一般在Linear层之后均会有一层Dropout
#---------------------------------------------#
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
#--------------------------------------------#
# 首先经过qkv()之后,shape为bs,197,2304(768*3)
# 之后将3单独拿出来放到第三维度之上
#--------------------------------------------#
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#--------------------------#
# q,k,v为三个shape完全相同的张量
#--------------------------#
q, k, v = qkv[0], qkv[1], qkv[2]
#-------------------------------------------------------------#
# 查询向量query点乘转置后的键向量key,之后除以scale,对应公式的中的分母
# scale即为self-attention公式中的分母,(dim // num_heads) ** -0.5
# 即为获得重要程度score,之后对其取一个softmax来表明对当前input的贡献
# 过Dropout来降低过拟合的风险
# 分别对应下面的三行code
# softmax(dim=-1)代表沿着最后一个维度运用softmax
#-------------------------------------------------------------#
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
#-------------------------------#
# 将所得score与value向量点乘
# 之后过全连接层以融合所学习到的特征
# 过Dropout来降低过拟合的风险
#-------------------------------#
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
#---------------------------------------#
# 此为transformer结构中的全连接层的定义
# 即为transformer结构中的MLP,对应自然语言处理
#---------------------------------------#
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)
self.fc1 = nn.Linear(in_features, hidden_features)
#-----------------------------------------------------------------#
# 在这里使用全连接层的目的为改变特征图像的通道数与卷积的作用相同
# 但全连接层的参数量过大,会导致过拟合,所以在其后加上Dropout以求降低过拟合的风险
# 此处所使用的为GELU激活函数
# drop_probs为定义的一个元组,其中的两个值均为drop
# drop_probs[0]以及drop_probs[1]的值均为0且为浮点数
#-----------------------------------------------------------------#
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
#---------------------------------------#
# transformerBlock的实现
# 典型的transformer结构中的Block数量为12
#---------------------------------------#
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
super(Block,self).__init__()
#--------------------------------------#
# 按照结构图进行构建,norm->att->norm->MLP
#--------------------------------------#
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
def __init__(self, input_shape=[224, 224], patch_size=16, in_channels=3, num_classes=1000, num_features=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):
super(VisionTransformer,self).__init__()
# -----------------------------------------------#
# bs, 224, 224, 3 -> bs, 196, 768
# -----------------------------------------------#
self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_channels=in_channels,num_features=num_features)
num_patches = (224 // patch_size) * (224 // patch_size)
self.num_features = num_features
self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]
# --------------------------------------------------------------------------------------------------------------------#
# cls token部分是transformer的分类特征,用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取
# 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺(flatten),一幅图片会存在一个序列长度为196的特征序列
# 此时生成一个cls token,将cls token堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
# 在特征提取的过程中,cls token会与图片特征进行特征的交互.最终分类时,我们取出cls token的特征,利用全连接进行分类
# --------------------------------------------------------------------------------------------------------------------#
# --------------------------------#
# 196, 768 -> 197, 768
# 将cls token堆叠至序列化后的图片特征中
# --------------------------------#
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))
# --------------------------------------------------------------------------------------------------------------------#
# 为网络提取到的特征添加上位置信息
# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768.加上cls token后就是197, 768
# 此时生成的pos_Embedding的shape也为197, 768.其代表每一个特征的位置信息。
# --------------------------------------------------------------------------------------------------------------------#
# 197, 768 -> 197, 768
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))
self.pos_drop = nn.Dropout(p=drop_rate)
# -----------------------------------------------------#
# 此为在位置信息嵌入完毕之后,过12次transformer_encoder模块
# 197, 768 -> 197, 768 12次
# -----------------------------------------------------#
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(
*[
Block(dim=num_features,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer
) for i in range(depth)
]
)
self.norm = norm_layer(num_features)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
#-----------------------------------------------#
# 将输入patch化,并将cls token堆叠到序列化后的图片特征中
#-----------------------------------------------#
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
#------------------------------#
# 图片中的种类信息以及图像信息嵌入
#------------------------------#
cls_token_pe = self.pos_embed[:, 0:1, :]
img_token_pe = self.pos_embed[:, 1:, :]
#---------------------------------------#
# 对tensor维度进行reshape之后又进行了置换操作
#---------------------------------------#
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
#------------------------------------------#
# 将输入的tensor转换成new_feature_shape
# 所使用的算法为:双三次插值算法
#------------------------------------------#
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
#------------------------------------------#
# 将输入的tensor转置并进行扁平化操作
#------------------------------------------#
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
#---------------------------------------------#
# 将cls_token_pe以及img_token_pe在维度一上进行堆叠
#---------------------------------------------#
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
#---------------------------------------------#
# 将生成的位置信息进行嵌入
# 之后经过12个transformer_encoder模块
# 之后再过一层LayerNorm
#---------------------------------------------#
x = self.pos_drop(x + pos_embed)
x = self.blocks(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
#---------------------------------------------#
# 最后经过一个head得到最终的输出
#---------------------------------------------#
x = self.head(x)
return x