-
CNN 的局限性:传统的 CNN 通过局部卷积核提取特征,虽然可以通过堆叠多层卷积扩大感受野,但仍然依赖于局部信息的逐步聚合,难以直接建模全局依赖关系。
-
ViT 的优势:ViT 使用自注意力机制(Self-Attention),能够直接捕捉图像中所有 patch(图像块)之间的全局关系。这种全局建模能力在处理需要长距离依赖的任务(如图像分类、目标检测)时表现更好。
全流程
-
图像预处理+分块
-
图像尺寸标准化,如(224*224)
-
分块操作,将图像分为等大小非重叠的块(Patch)
-
-
块嵌入(Patch embedding)
-
展平成一维,然后过一个Linear映射成d_model,或者直接卷积操也能达到相同效果。
-
类别标记,[class Toekn] 在块嵌入序列前添加可学习的类别标记,用于最终的分类任务 【这个比如说分成了196个bacth,那么此时的张量形状应该是(batch_size,patch_num,patchH*patchW)】,这个加类别相当于patch_num+1,作为类别标记
-
-
位置编码,将位置嵌入添加到嵌入前边
-
Transformer Encoder处理
-
编码器堆叠,含类别标记的嵌入经过多个transformer Encoder处理,出来一个增强表示
-
transformer里边的MLP通常拓展4倍维度
-
-
分类输出:仅仅取序列中的类别标记作为高层特征() 输入MLP Head,输出结果
图像分patch方法
认为一个图像可以分成若干个相等大小的块,比如3x224x224格式的图像,如果patch大小为14x14,则分为256个patch,这个时候每个patch展平就是长度为196的vector。
对patch做embedding
相当于对wordVector做embedding,之后一串embedding序列传递给transformer的Encoder就可以得到全局表示
class PatchEmbeddings(nn.Module):
def __init__(self,d_model,pacth_size,in_channels):
super().__init__()
self.conv=nn.Conv2d(in_channels,d_model,pacth_size,stride=pacth_size)
def forward(self,x):
x=self.conv(x)
batch_size,c,h,w=x.shape
x=x.permute(2,3,0,1)
x=x.view(h*w,batch_size,c)
# (h*w,batch_size,d_model)
return x
位置编码
import torch
class LearnedPositionalEmbeddings(nn.Module):
def __init__(self,d_model,max_len=5000):
super().__init__()
self.positional_encodings=nn.Parameter(torch.zeros(max_len,1,d_model),requires_grad=True);
def forward(self,x):
pe=self.positional_encodings[:x.shape[0]]
return x+pe
分类器
class ClassficationHead(nn.Module):
def __init__(self,d_model,n_hidden,n_classes):
super().__init__()
self.linear1=nn.Linear(d_model,n_hidden)
self.act=nn.ReLU()
self.linear2=nn.Linear(n_hidden,n_classes)
def forward(self,x):
x=self.act(self.linear1(x))
x=self.linear2(x)
return x
VIT总体
class VisionTransformer(nn.Module):
def __init__(self,transformer_layer,n_layers,pacth_emb,pos_emb,classification):
super().__init__()
self.pacth_emb=pacth_emb
self.pos_emb=pos_emb
self.classification=classification
self.transformer_layer = nn.ModuleList([transformer_layer for _ in range(n_layers)]) #n个transformer layer
self.cls_token_emb=nn.Parameter(torch.randn(1,1,transformer_layer.size),requires_grad=True)
self.ln=nn.LayerNorm([transformer_layer.size])
def forward(self,x):
x=self.pacth_emb(x)
cls_token_emb=self.cls_token_emb.expand(-1,x.shape[1],-1)
x=torch.cat([cls_token_emb,x])
x=self.pos_emb
for layer in self.transformer_layer:
x=layer(x=x,mask=None)
x=x[0]
x=self.ln(x)
x=self.classification(x)