1. Paper Title
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ViT, ICLR 2021)
2. Background and Motivation
In deep learning, convolutional neural networks (CNNs) have long dominated computer vision, achieving strong results in image classification, object detection, and related tasks. However, CNNs are built around local receptive fields and weight sharing, which limits their ability to capture global information. In recent years, following the enormous success of the Transformer architecture in natural language processing (NLP), researchers began to explore how to bring this powerful sequence-modeling capability into computer vision.
3. Key Ideas and Innovations
The main innovation of the Vision Transformer (ViT) is applying the Transformer architecture directly to images.
Global attention mechanism: ViT's biggest departure is that it uses the Transformer's self-attention mechanism instead of convolutional layers, so the model can analyze the input image globally. This makes it easier to capture long-range dependencies in the image and can improve performance.
Patch-based serialization: ViT splits the whole image into fixed-size patches, linearly embeds them, and feeds the resulting sequence to a Transformer. This "image as a sequence" treatment lets ViT reuse mature NLP techniques directly and gives it strong adaptability and flexibility in image recognition.
Pre-training and fine-tuning: ViT adopts the pre-train-then-fine-tune strategy of models such as BERT. It is first pre-trained on a large-scale labeled dataset (e.g., ImageNet-21k or JFT-300M) and then fine-tuned on the target task, which greatly improves generalization and transfer learning performance (a minimal sketch follows below).
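The sketch below only illustrates this pre-train/fine-tune workflow; it is not the paper's own training code. It assumes the third-party timm library is installed and that the checkpoint name "vit_base_patch16_224" is available for download.

```python
import timm
import torch
from torch import nn

# Load a ViT backbone pre-trained on a large dataset (weights are downloaded by timm).
# num_classes=10 replaces the classification head with a fresh nn.Linear for the new task.
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)

# Fine-tune: typically a small learning rate on all parameters, or freeze the backbone
# and train only the new head, depending on the size of the downstream dataset.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

images = torch.randn(4, 3, 224, 224)       # dummy batch of images
labels = torch.randint(0, 10, (4,))        # dummy labels
optimizer.zero_grad()
logits = model(images)                     # (4, 10)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
```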
4. Implementation Details
Network Architecture
Implementation Steps
- The input image has shape [224, 224, 3].
- Patch Embedding: a convolution followed by Flatten() turns the image into tokens of shape [196, 768]. With 16×16 patches, 224 / 16 = 14 gives a 14 × 14 = 196 patch grid, and each flattened patch has 16 × 16 × 3 = 768 values (see the sketch after this list).
- Class Embedding: a trainable cls_token is concatenated with the patch tokens, i.e. tokens = torch.cat([cls_token, tokens], dim=1); the shape changes from [196, 768] to [197, 768].
- Position Embedding: a trainable positional encoding pos_tokens is added, i.e. tokens = tokens + pos_tokens; the shape stays [197, 768] → [197, 768].
- A Dropout layer is applied.
- The tokens pass through several Transformer Encoder layers.
- A LayerNorm layer is applied; the shape is still [197, 768].
- The output corresponding to the Class Token is extracted. This is implemented by slicing the [197, 768] tensor and keeping only the Class Token's output, a [1, 768] vector.
- An MLP Head produces the final output. When training on your own dataset, the Pre-Logits block can be dropped, so the MLP Head is simply an nn.Linear.
- Softmax converts the logits into probability outputs.
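As a sanity check of the shapes above, here is a minimal sketch of the embedding pipeline. The patch size 16 and embedding dimension 768 are taken from the numbers in the list (a ViT-Base/16 style configuration); the variable names are illustrative only.

```python
import torch
from torch import nn

img = torch.randn(1, 3, 224, 224)                     # input image, [224, 224, 3] in NCHW form

# Patch Embedding: a 16x16 conv with stride 16 plus Flatten, as described above
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)   # each 16x16x3 patch -> a 768-dim token
tokens = proj(img)                                     # (1, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)             # (1, 196, 768)

# Class Embedding: prepend a trainable cls_token -> (1, 197, 768)
cls_token = nn.Parameter(torch.zeros(1, 1, 768))
tokens = torch.cat([cls_token.expand(tokens.shape[0], -1, -1), tokens], dim=1)

# Position Embedding: add a trainable positional encoding, shape unchanged
pos_tokens = nn.Parameter(torch.zeros(1, 197, 768))
tokens = tokens + pos_tokens
print(tokens.shape)                                    # torch.Size([1, 197, 768])
```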
Minimal Code Example

```python
import torch
from torch import nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_dim, heads):
super().__init__()
self.embed_dim = embed_dim
self.heads = heads
self.head_dim = embed_dim // heads
assert (
self.head_dim * heads == embed_dim
), "embed_dim must be divisible by heads"
        # Split the embedding into multiple heads; each head gets its own linear projection
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Output linear layer that recombines the multi-head outputs
self.fc_out = nn.Linear(heads * self.head_dim, embed_dim)
def forward(self, values, keys, queries, mask):
        N = queries.shape[0]  # batch size
value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        # Reshape the inputs for multi-head attention:
        # (N, seq_length, embed_dim) -> (N, seq_length, heads, head_dim)
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        # Project through the per-head linear layers, then compute the attention scores;
        # the score tensor has shape (N, heads, query_len, key_len)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
        # Compute the query-key dot products with einsum
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Scaled dot-product attention: scale by the square root of each head's dimension
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)
        # Multiply the attention weights with the values and merge the heads back
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.heads * self.head_dim
)
        # Final linear projection
out = self.fc_out(out)
return out
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, heads, dropout, forward_expansion):
super().__init__()
self.attention = SelfAttention(embed_dim, heads)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
        # Feed-forward network: two linear layers with a ReLU in between
self.feed_forward = nn.Sequential(
nn.Linear(embed_dim, forward_expansion*embed_dim),
nn.ReLU(),
nn.Linear(forward_expansion*embed_dim, embed_dim)
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask):
attention = self.attention(value, key, query, mask)
        # Residual (skip) connection followed by LayerNorm and dropout
        # (post-norm variant; the original ViT applies LayerNorm before attention/MLP)
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
class ViT(nn.Module):
def __init__(self, img_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.1, emb_dropout=0.1):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.dim = dim
self.depth = depth
self.num_patches = (img_size // patch_size) ** 2
self.patch_dim = channels * patch_size ** 2
        # Initialize the positional encoding and the CLS token
self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches+1, dim))
self.patch_to_embedding = nn.Linear(self.patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
        # Stack several Transformer blocks. A ModuleList is used (instead of nn.Sequential)
        # so that forward can pass query/key/value and the mask explicitly.
        # mlp_dim is converted to the expansion factor expected by TransformerBlock.
        self.transformer = nn.ModuleList(
            [TransformerBlock(dim, heads, dropout, mlp_dim // dim) for _ in range(depth)]
        )
self.to_cls_token = nn.Identity()
        # MLP head for classification
self.mlp_head = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.ReLU(),
nn.Linear(mlp_dim, num_classes)
)
    def forward(self, x, mask=None):
        N, C, H, W = x.shape  # image dimensions
        # Split the image into non-overlapping patches and flatten each patch
        x = x.view(N, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size)
        x = x.permute(0, 2, 4, 3, 5, 1)                     # N, nH, nW, patch_size, patch_size, C
        x = x.reshape(N, self.num_patches, self.patch_dim)  # N, num_patches, patch_dim
        patches = self.patch_to_embedding(x)                # N, num_patches, dim
        cls_tokens = self.cls_token.expand(N, -1, -1)       # N, 1, dim
        x = torch.cat((cls_tokens, patches), dim=1)         # N, num_patches+1, dim
        x = x + self.pos_embedding                          # add the positional encoding
        x = self.dropout(x)
        for block in self.transformer:                      # each block keeps shape N, num_patches+1, dim
            x = block(x, x, x, mask)                        # self-attention: value = key = query = x
        x = self.to_cls_token(x[:, 0])                      # N, dim
        return self.mlp_head(x)                             # N, num_classes
# Hyperparameters (adjust for your data and compute budget)
img_size = 256       # image size
patch_size = 32      # patch size
num_classes = 10     # number of classes
channels = 3         # number of image channels
dim = 1024           # embedding dimension
depth = 6            # number of Transformer blocks
heads = 8            # number of attention heads
mlp_dim = 2048       # hidden dimension of the feed-forward network
dropout = 0.1        # dropout rate
emb_dropout = 0.1    # embedding dropout rate
# Instantiate the model
vit = ViT(
img_size=img_size,
patch_size=patch_size,
num_classes=num_classes,
dim=dim,
depth=depth,
heads=heads,
mlp_dim=mlp_dim,
channels=channels,
dropout=dropout,
emb_dropout=emb_dropout
)
# Create a dummy image tensor (batch size, channels, height, width)
img = torch.randn(1, channels, img_size, img_size)
# Get the predictions (logits)
preds = vit(img)  # shape: (1, num_classes)
print(preds)
```
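The model above returns raw logits. To obtain the probability output mentioned in the last step of the pipeline, a softmax can be applied to the logits. This snippet simply post-processes the demo output; it is not part of the original listing.

```python
import torch.nn.functional as F

probs = F.softmax(preds, dim=-1)    # (1, num_classes), each row sums to 1
pred_class = probs.argmax(dim=-1)   # index of the most likely class
print(probs, pred_class)
```

Note that nn.CrossEntropyLoss already applies log-softmax internally, so during training the logits should be passed to the loss directly; the explicit softmax is only needed at inference time.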