### Implementing a Vision Transformer (ViT)
To implement a Vision Transformer, we can build the model with the PyTorch framework. Below is a simplified ViT implementation:
```python
import torch
from torch import nn
class PatchEmbedding(nn.Module):
    """Split the image into patches and project each patch to embed_dim."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patch_shape = (num_patches, patch_size * patch_size * 3)
        self.projection = nn.Linear(self.patch_shape[-1], embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        # Carve out non-overlapping p x p patches: (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Bring channel and pixel dims together per patch: (B, N, C*p*p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(B, -1, self.patch_shape[-1])
        return self.projection(patches)
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** (-0.5)
        self.qkv = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (B, heads, N, head_dim) so attention runs per head
        q, k, v = map(lambda t: t.reshape(B, N, self.heads, -1).transpose(1, 2), qkv)
        attn_scores = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn_probs, v)
        # Merge the heads back into a single embedding: (B, N, dim)
        out = out.transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)
class DropPath(nn.Module):
    """Stochastic depth: randomly skip the residual branch for some samples."""
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all other dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, device=x.device, dtype=x.dtype).bernoulli_(keep_prob)
        return x * mask / keep_prob

class EncoderBlock(nn.Module):
    """Encoder layer: multi-head self-attention + feed-forward network."""
    def __init__(self, dim, mlp_ratio=4., drop_rate=0.1, heads=8):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.attn_norm = nn.LayerNorm(dim)
        self.mhsa = MultiHeadSelfAttention(dim, heads=heads)
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()
        self.ffn_norm = nn.LayerNorm(dim)
        self.dense1 = nn.Linear(dim, hidden_features)
        self.act_fn = nn.GELU()
        self.dense2 = nn.Linear(hidden_features, dim)
    def forward(self, x):
        # Pre-norm attention sub-layer with residual connection
        h = x
        x = self.attn_norm(x)
        x = self.mhsa(x)
        x = h + self.drop_path(x)
        # Pre-norm feed-forward sub-layer with residual connection
        h = x
        x = self.ffn_norm(x)
        x = self.dense1(x)
        x = self.act_fn(x)
        x = self.dense2(x)
        x = h + self.drop_path(x)
        return x
class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 drop_rate=0.1,
                 num_classes=1000):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        # One position slot per patch, plus one for the [CLS] token
        num_patches = (img_size // patch_size) ** 2
        self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.positional_encoding, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        encoder_layers = []
        for _ in range(depth):
            layer = EncoderBlock(
                dim=embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                heads=num_heads
            )
            encoder_layers.append(layer)
        self.encoder_blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.fc_head = nn.Linear(embed_dim, num_classes)
    def forward(self, x):
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = self.patch_embedding(x)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positional_encoding
        x = self.encoder_blocks(x)
        # Classify from the final [CLS] token representation
        x = self.norm(x[:, 0])
        logits = self.fc_head(x)
        return logits
```
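Before wiring this into a training pipeline, a quick shape check is worthwhile. The snippet below is a minimal smoke test assuming only the definitions above; the small `depth=2` and `num_classes=10` are arbitrary values chosen to keep it fast:

```python
# Smoke test: push a dummy batch through a small ViT and check the output shape
model = VisionTransformer(depth=2, num_classes=10)
model.eval()

dummy_images = torch.randn(4, 3, 224, 224)  # (batch, channels, height, width)
with torch.no_grad():
    logits = model(dummy_images)

print(logits.shape)  # expected: torch.Size([4, 10])
```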
The code above defines a basic version of the Vision Transformer architecture, including its core components: the patch embedding module, multi-head self-attention, and the encoder block[^1].
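To complete the picture, here is a sketch of a single training step, reusing the definitions above. The random tensors stand in for a real DataLoader, and the AdamW hyperparameters are illustrative defaults rather than tuned values:

```python
from torch import optim

model = VisionTransformer(depth=2, num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

# Stand-in batch; swap in a real DataLoader for actual training
images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 10, (8,))

model.train()
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```

In practice one would also add data augmentation and a learning-rate schedule, since ViT is known to be data-hungry when trained from scratch.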