Below is a simplified implementation of a ViT (Vision Transformer) model. ViT performs image classification by splitting an image into small patches, treating each patch as a token, and feeding the resulting token sequence into a Transformer encoder.
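To make the patch idea concrete, here is a minimal standalone sketch (not part of the model below; the tensor names are illustrative) that splits an image tensor into flattened patches with einops:

import torch
from einops import rearrange

img = torch.randn(1, 3, 224, 224)  # (B, C, H, W), illustrative input
# Split into 16x16 patches and flatten each patch into a vector.
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape)  # torch.Size([1, 196, 768]): 14*14 patches, each 16*16*3 values

The model below achieves the same split with a strided convolution, which additionally learns the projection into the embedding space.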
import torch
import torch.nn as nn
from einops import rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        # A strided convolution splits the image into patches and projects
        # each patch to the embedding dimension in a single step.
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embedding = nn.Parameter(torch.randn(1, (img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)                               # (B, emb_size, H/patch_size, W/patch_size)
        x = rearrange(x, 'b e h w -> b (h w) e')       # (B, N, emb_size)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, emb_size)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, N+1, emb_size)
        x = x + self.pos_embedding                     # (B, N+1, emb_size)
        return x
class TransformerEncoderLayer(nn.Module):
    def __init__(self, emb_size=768, num_heads=12, forward_expansion=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_size)
        # batch_first=True so attention accepts (B, N, emb_size) inputs,
        # matching the shape produced by PatchEmbedding.
        self.mha = nn.MultiheadAttention(emb_size, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(emb_size)
        self.ffn = nn.Sequential(
            nn.Linear(emb_size, forward_expansion * emb_size),
            nn.GELU(),
            nn.Linear(forward_expansion * emb_size, emb_size),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Pre-norm residual blocks: LayerNorm -> attention/FFN -> skip connection.
        y = self.ln1(x)
        x = x + self.mha(y, y, y)[0]
        x = x + self.ffn(self.ln2(x))
        return x
class TransformerEncoder(nn.Module):
    def __init__(self, depth=12, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([TransformerEncoderLayer(**kwargs) for _ in range(depth)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class ViT(nn.Module):
    def __init__(self, num_classes=1000, img_size=224, patch_size=16, in_channels=3,
                 emb_size=768, depth=12, num_heads=12, forward_expansion=4, dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.transformer = TransformerEncoder(depth, emb_size=emb_size, num_heads=num_heads,
                                              forward_expansion=forward_expansion, dropout=dropout)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, num_classes)
        )

    def forward(self, x):
        x = self.patch_embedding(x)
        x = self.transformer(x)
        x = x[:, 0]  # take the classification token
        x = self.mlp_head(x)
        return x
# Test the model
if __name__ == '__main__':
    model = ViT()
    img = torch.randn(1, 3, 224, 224)  # a single image
    preds = model(img)
    print(preds.shape)  # torch.Size([1, 1000])
Code notes:
- PatchEmbedding: splits the input image into patches and projects them into the embedding space with a strided convolution; it also prepends a learnable classification token and adds positional embeddings. (A quick shape check follows this list.)
- TransformerEncoderLayer: each encoder layer combines multi-head self-attention with a feed-forward network, each wrapped in a pre-norm residual connection.
- TransformerEncoder: a stack of depth identical encoder layers.
- ViT: ties together the patch embedding, the Transformer encoder, and the classification head.
- Model test: runs a forward pass on a single image and prints the shape of the predictions.
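As a quick sanity check on the shapes described above, here is a minimal sketch assuming the classes defined earlier are in scope:

pe = PatchEmbedding()                      # defaults: patch_size=16, img_size=224
tokens = pe(torch.randn(2, 3, 224, 224))
print(tokens.shape)                        # torch.Size([2, 197, 768]): 196 patches + 1 CLS token

layer = TransformerEncoderLayer()
print(layer(tokens).shape)                 # torch.Size([2, 197, 768]): encoder layers preserve shape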
This simplified ViT can serve as a starting point for image classification; with suitable training it can perform well across a range of datasets.
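As a starting point for that training, a minimal, illustrative training step might look like the following; the optimizer choice and hyperparameters are assumptions, and a real data loader yielding (image, label) batches is not shown:

import torch
import torch.nn as nn

model = ViT(num_classes=10)                                 # e.g. a 10-class dataset
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # illustrative hyperparameters
criterion = nn.CrossEntropyLoss()

def train_step(images, labels):
    # images: (B, 3, 224, 224); labels: (B,) with integer class indices
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()

# One step on random data, just to show the mechanics:
print(train_step(torch.randn(8, 3, 224, 224), torch.randint(0, 10, (8,))))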