Introduction to ViT
ViT is a model proposed by a Google team in 2020 that applies the Transformer to image classification. Although it was not the first paper to use Transformers for vision tasks, its "simple" architecture, strong results, and good scalability (the larger the model, the better the performance) made it a milestone for Transformers in computer vision and sparked a wave of follow-up research.
Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
ViT splits an image into a series of patches and converts each patch into a vector, forming the input sequence. These vectors are then processed by a stack of Transformer encoder layers, each containing self-attention and a feed-forward network, which captures contextual dependencies between different positions in the image. Finally, the encoder output is fed to a classification (or regression) head to complete the target vision task.
Why can't the Transformer be applied to images directly?
The Transformer was designed for sequence tasks (such as NLP), whereas an image is a 2D (or 3D, counting channels) grid of pixels with strong local structure. If we naively treat every pixel as a token, self-attention has to relate every pixel to every other pixel, and the computation becomes prohibitively expensive.
This is the problem ViT was designed to solve.
Below, the network is explained in detail with the help of code.
1. Overall ViT Architecture
The ViT architecture is shown in the figure below.
ViT applies the Transformer to image patches. A CNN slides convolution kernels over the image to produce feature maps; to mimic the token sequence used in NLP, ViT instead splits the image into patches, flattens each patch, and feeds the resulting sequence into the network. Features are then extracted by a Transformer and finally classified by an MLP head. (You can think of it as a conventional CNN classification pipeline in which the backbone is replaced by a Transformer.)
As illustrated in the figure below, the overall pipeline can be divided into six steps:
Step 1: Split the image into a sequence of patches
Step 2: Flatten the patches
Step 3: Add position embeddings
Step 4: Add the class token
Step 5: Feed the sequence into the Transformer Encoder
Step 6: Classify
How each step is implemented is explained further through the code below.
2. Code Walkthrough
Think of the ViT model as a small toy that we are going to take apart and play with. What we cover here is the most basic, simplest, and most accessible version of the model; only after understanding the basics can you build on them and enrich the model further. Also, this walkthrough focuses less on why each design choice was made and more on how each step is implemented.
2.1 ViT Code
First, the input image has shape (3, 224, 224), where 3 is the number of channels and 224 x 224 is the width and height.
The image is first split into patches. With the usual patch size of 16 x 16, this produces (224/16) * (224/16) = 196 patches. How is this implemented in code?
It is actually very simple: a single convolution layer does the job, with kernel size equal to the patch size (16) and stride equal to the patch size (16).
nn.Conv2d(channels, dim, kernel_size=patch_height, stride=patch_height)
Then, in forward, the convolution output is flattened and transposed, giving an output of shape (1, 196, 768), where 1 is the batch size.
# flatten: [B, C, H, W] -> [B, C, HW], flatten starting from dim 2
# transpose: [B, C, HW] -> [B, HW, C]
x = self.to_patch_embedding(img).flatten(2).transpose(1, 2)  ## img: (1, 3, 224, 224) -> x: (1, 196, 768)
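As a quick sanity check, here is a minimal, self-contained sketch of this patchify step (the layer below is a stand-alone copy for illustration, not the one inside the ViT class):

import torch
from torch import nn

# Patchify with a conv whose kernel size and stride both equal the patch size (16)
patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)

img = torch.randn(1, 3, 224, 224)      # (B, C, H, W)
x = patch_embed(img)                   # (1, 768, 14, 14): one 768-dim vector per 16x16 patch
x = x.flatten(2).transpose(1, 2)       # (1, 196, 768): a sequence of 196 patch tokens
print(x.shape)                         # torch.Size([1, 196, 768])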
Next we add the position embedding and the class token: the former is added element-wise, the latter is concatenated. After prepending the class token, the feature shape becomes (1, 197, 768).
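A minimal sketch of this step (the class token and position embedding here are plain tensors for illustration; in the real model they are learnable nn.Parameter weights):

import torch

dim, num_patches = 768, 196
x = torch.randn(1, num_patches, dim)                  # patch tokens: (1, 196, 768)
cls_token = torch.zeros(1, 1, dim)                    # learnable nn.Parameter in the real model
pos_embedding = torch.zeros(1, num_patches + 1, dim)  # learnable nn.Parameter in the real model

x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)  # prepend class token -> (1, 197, 768)
x = x + pos_embedding                                 # position embedding is added element-wise
print(x.shape)                                        # torch.Size([1, 197, 768])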
2.1.1 Class Embedding
There are two ways to obtain the classification feature:
- use the class token;
- use global average pooling (mean pooling).
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
Both approaches work; the class token is preferred here because the goal was to apply the original, unmodified Transformer to the CV domain.
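A minimal sketch contrasting the two pooling options on a dummy encoder output:

import torch

tokens = torch.randn(1, 197, 768)      # encoder output: class token + 196 patch tokens

cls_feat = tokens[:, 0]                # 'cls'  pooling: take the class token  -> (1, 768)
mean_feat = tokens.mean(dim=1)         # 'mean' pooling: average all tokens    -> (1, 768)
print(cls_feat.shape, mean_feat.shape)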
2.1.2 Position Embedding
The position embedding is added to the patch embeddings to encode where each patch is located in the original image.
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
The position embedding varies with position: the larger the positional difference, the larger the difference between the embeddings (here the embeddings are learned parameters rather than fixed sinusoids).
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)  ## 224 * 224
        patch_height, patch_width = pair(patch_size)  ## 16 * 16
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_height, stride=patch_height),
            # Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # nn.Linear(patch_dim, dim),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # flatten: [B, C, H, W] -> [B, C, HW], flatten starting from dim 2
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.to_patch_embedding(img).flatten(2).transpose(1, 2)  ## img: (1, 3, 224, 224) -> x: (1, 196, 768)
        b, n, _ = x.shape
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
2.2 Transformer Code
The Transformer Encoder simply stacks the encoder block L times (the depth). Each block consists of two parts, multi-head attention (Attention) and a feed-forward network (FeedForward), each preceded by a LayerNorm layer.
The code for the Transformer and the pre-LayerNorm wrapper is shown below.
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
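PreNorm is a generic wrapper: it applies LayerNorm to the input before passing it to any sub-layer (attention or MLP). This "pre-norm" arrangement is generally considered easier to train than the post-norm layout of the original Transformer. A small usage sketch, assuming the PreNorm class above and an arbitrary sub-layer:

import torch
from torch import nn

block = PreNorm(768, nn.Linear(768, 768))  # LayerNorm(768) followed by any sub-layer
x = torch.randn(1, 197, 768)
out = x + block(x)                         # residual connection, as in Transformer.forward
print(out.shape)                           # torch.Size([1, 197, 768])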
2.3 Attention Code
Multi-head attention comes from the paper "Attention Is All You Need".
The implementation of multi-head attention is as follows.
First, a single fully connected layer produces the initial values of q, k, and v:
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
Then reshape and permute are used to rearrange the dimensions:
qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
Finally, slicing separates Q, K, and V:
q, k, v = qkv[0], qkv[1], qkv[2]
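To make the shape bookkeeping concrete, here is a small sketch using the numbers from this article (dim = 768, heads = 8, sequence length 197 = 196 patches + 1 class token), so each head works with 768 / 8 = 96 dimensions:

import torch
from torch import nn

B, N, dim, heads = 1, 197, 768, 8
to_qkv = nn.Linear(dim, dim * 3, bias=False)

x = torch.randn(B, N, dim)
qkv = to_qkv(x).reshape(B, N, 3, heads, dim // heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)   # torch.Size([1, 8, 197, 96])  -> [B, heads, N, head_dim]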
Next comes the implementation of the scaled dot-product attention from the original Transformer paper:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
where d_k is the per-head dimension (768 / 8 = 96 here), corresponding to self.scale = dim_head ** -0.5 in the code.
The dimension change at each step is annotated in detail in the comments.
The full code is as follows:
class Attention(nn.Module):
    def __init__(self, dim, heads, dropout = 0.):
        super().__init__()
        # inner_dim = dim_head * heads
        dim_head = dim // heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        B, N, C = x.shape
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        # out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
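A quick shape check of the Attention module defined above: the output has the same shape as the input, which is what makes the residual connection x = attn(x) + x in the Transformer possible.

import torch

attn = Attention(dim=768, heads=8, dropout=0.)
x = torch.randn(1, 197, 768)
out = attn(x)
print(out.shape)   # torch.Size([1, 197, 768])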
2.4 FeedForward Code (MLP)
Next is the MLP module, i.e., the feed-forward network from the original Transformer paper. This part is simple: two fully connected layers with a GELU activation in between, plus dropout. The flow is shown in the diagram below.
The code is as follows:
class FeedForward(nn.Module):  # MLP
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
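A quick usage sketch (the instantiation at the end of this article uses mlp_dim = 768*4, i.e., the hidden layer is four times wider than the token dimension):

import torch

ff = FeedForward(dim=768, hidden_dim=768 * 4, dropout=0.1)
x = torch.randn(1, 197, 768)
print(ff(x).shape)   # torch.Size([1, 197, 768]); width expands to 3072 internally, then projects back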
Summary
Finally comes the classification head. It is simple; a single fully connected layer (preceded here by a LayerNorm) is enough.
self.mlp_head = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, num_classes)
)
For readers new to ViT, note that only the class token is fed into the MLP classification head. After N Transformer encoder layers, the feature tensor still has the same shape as the input, [197, 768] per sample; we simply slice out the class token, of shape [1, 768], and use it for classification.
You might wonder whether the other patch features are wasted. They are not: multi-head self-attention lets features at all positions interact fully, so the output class token is very different from the input class token, and that change is driven by all the other patch features.
Finally, here is the complete code for the model:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):  # MLP
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads, dropout = 0.):
        super().__init__()
        # inner_dim = dim_head * heads
        dim_head = dim // heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        B, N, C = x.shape
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        # out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)  ## 224 * 224
        patch_height, patch_width = pair(patch_size)  ## 16 * 16
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_height, stride=patch_height),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img).flatten(2).transpose(1, 2)  ## img: (1, 3, 224, 224) -> x: (1, 196, 768)
        b, n, _ = x.shape
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)

v = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 768,
    depth = 6,
    heads = 8,
    mlp_dim = 768*4,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)
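For completeness, here is a minimal training-step sketch built on the model instantiated above; the loss, optimizer, learning rate, and random labels are illustrative choices, not part of the original article:

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(v.parameters(), lr=3e-4, weight_decay=0.05)  # illustrative hyperparameters

imgs = torch.randn(8, 3, 224, 224)      # a dummy batch of images
labels = torch.randint(0, 1000, (8,))   # dummy class labels

logits = v(imgs)                        # (8, 1000)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())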
Why ViT Matters
Having learned how the network works, it is worth asking why it matters.
The following discussion is adapted from another blogger; the link is given at the end.
In industry, labeled data and compute are both limited, so CNNs may still be the first choice. The significance of ViT, however, goes beyond raw model accuracy.
Looking back at the model today, its significance lies in the following:
It demonstrated that a single unified architecture can perform well across modalities. Before ViT, the SOTA paradigm in NLP was the Transformer, while the SOTA paradigm for images was still the CNN. ViT showed that the SOTA model from NLP can solve image problems as well; through extensive experiments, the paper argued that ViT can replace CNNs and showed that large-scale data plus large models bring emergent-like gains in vision (the paper does not use the term "emergence", but the experiments show the trend). This laid the groundwork for the multimodal work of the following years.
Although ViT itself only addresses classification, Transformer-based models for detection and segmentation appeared within months of its release, and GPT-style unsupervised learning soon took off in CV as well.
In industry, most companies do not have the data or compute to pretrain, or even fine-tune, a ViT themselves, but that does not stop them from using features from ViTs trained by large labs for downstream tasks. Research on low-cost fine-tuning is also flourishing. In the long run, this "behemoth" from two years ago is steadily finding its way into everyday use.
Original post: https://blog.csdn.net/m0_37605642/article/details/133821025