在图像分类领域,卷积神经网络(CNNs)长期占据主导地位,因其具备平移不变性和局部受限感受野等归纳偏置。然而,Transformer的出现为图像分类带来了新的思路。本文将详细探讨Transformer架构在图像分类中的微调,即Vision Transformer(ViT)的工作原理、重要细节以及具体实现。
ViT架构简述
-
图像分块与嵌入
- 首先,将图像分割成多个图像块(patches),这些图像块类似于文本序列中的单词。例如,对于一个输入图像(\mathbf(x) \in R^{H \times W \times C}) ,设定patch大小为 (p),我们可以将其转换为 (N) 个图像块 (\mathbf(x)_p \in R^{N \times (P^{2} C)}) ,其中 (N=\frac{H W}{P^{2}})。这里的 (N) 类似于句子中单词的数量,代表序列长度。
- 以 ( [16,16,3]) 的图像块为例,会被展平为 (16\times16\times3) 的向量。然后通过一个线性变换层将这些图像块转换为低维线性嵌入,其输入是 (P^{2} C) 维的向量,输出为 (D) 维。
-
位置嵌入
- 尽管Transformer架构本身是排列不变的,但为了让模型能够处理图像中的空间信息,需要添加位置嵌入(positional embeddings)。有趣的是,研究发现多种位置嵌入方案在实际应用中并没有显著差异。这可能是因为Transformer编码器在图像块级别上操作,相比于学习整个图像高度和宽度的位置关系,学习图像块之间的位置关系相对容易。
- 例如,想象一下解一个100片的拼图(图像块)和5000片的拼图(像素),显然前者更容易理解各部分之间的关系。在低维线性投影后,会添加一个可训练的位置嵌入到图像块表示中。训练后的位置嵌入呈现出一定的二维结构,并且行(列)之间的模式具有相似的表示。
-
标准Transformer编码器
- 将经过位置嵌入的图像块序列作为输入,送入标准的Transformer编码器。这里的编码器块与Vaswani等人在2017年提出的原始Transformer编码器相同,只是块的数量有所变化。为了证明在更多数据下可以训练更大的ViT变体,研究人员提出了3种模型。
- 模型中的 “Heads” 指多头注意力(multi - head attention),“MLP size” 指图中的蓝色模块,MLP即多层感知机,实际上是由一系列线性变换层组成。隐藏层大小 (D) 在各层中保持固定,这样可以使用短残差跳跃连接。
-
分类层
- ViT架构中没有解码器,而是在最后添加一个额外的线性层,称为MLP head,用于最终的分类任务。
重要细节
- 数据需求
- ViT在训练时对数据量有较高要求。如果在超过1400万张图像的数据集上训练,它能够接近或超越当前最先进的CNNs。若数据量不足,使用ResNets或EfficientNets可能是更好的选择。
- ViT通常先在大型数据集上进行预训练,然后在小型数据集上进行微调。微调时,需要丢弃预训练模型的预测头(MLP head),并附加一个新的 (D \times K) 线性层,其中 (K) 是小型数据集的类别数。
- 分辨率调整
- 研究人员发现,在微调时使用比预训练更高的分辨率效果更好。为了在更高分辨率下进行微调,需要对预训练的位置嵌入进行二维插值。这是因为位置嵌入是通过可训练的线性层建模的。
ViT中的关键发现
- 早期层特征可视化
- 在卷积神经网络早期,人们常对早期层进行可视化,因为训练良好的网络通常会显示出漂亮且平滑的滤波器。通过主成分分析(PCA)对ViT早期层进行可视化发现,早期层表示可能共享相似的特征。
- 非局部交互距离
- 对于图像块大小 (P),在ViT中从第一层开始,学习到的非局部交互距离最大可达 (P \times P),在本文例子中为128。而在传统卷积中,若不使用扩张卷积,感受野是线性增加的,要达到128的感受野需要较多的卷积层。
- 研究还发现,注意力距离会随着网络深度增加,类似于局部操作的感受野。同时,在低层也存在注意力距离始终较小的注意力头。为了验证高度局部化注意力头的想法,研究人员对在Transformer前应用ResNet的混合模型进行了实验,发现局部化注意力头减少,这表明其可能与CNNs早期卷积层具有相似的功能。
- 注意力距离的计算
- 注意力距离的计算方式为查询像素与图像块其余部分之间的平均距离乘以注意力权重。研究人员使用128个示例图像并对结果进行平均。例如,如果一个像素距离为20像素,注意力权重为0.5,则距离为10。
ViT的实现
下面是基于PyTorch和einops库的ViT实现代码框架:
import torch
import torch.nn as nn
from einops import rearrange
from self_attention_cv import TransformerEncoder
class ViT(nn.Module):
def __init__(self, *, img_dim, in_channels = 3, patch_dim = 16, num_classes = 10, dim = 512, blocks = 6, heads = 4, dim_linear_block = 1024, dim_head = None, dropout = 0, transformer = None, classification = True):
"""
Args:
img_dim: 图像空间大小
in_channels: 图像通道数
patch_dim: 期望的图像块维度
num_classes: 分类任务类别数
dim: 用于多头自注意力(MHSA)中投影图像块的线性层维度
blocks: Transformer块的数量
heads: 头的数量
dim_linear_block: Transformer线性块的内部维度
dim_head: 头的维度,如果未定义则默认为dim/heads
dropout: 位置嵌入和Transformer的 dropout
transformer: 若要提供其他Transformer实现
classification: 是否创建额外的CLS标记
"""
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks = blocks, heads = heads, dim_head = self.dim_head, dim_linear_block = dim_linear_block, dropout = dropout)
else:
self.transformer = transformer
def expand_cls_to_batch(self, batch):
"""
Args:
batch: 批次大小
Returns:
扩展到批次大小的CLS标记
"""
return self.cls_token.expand([batch, -1, -1])
def forward(self, img, mask = None):
batch_size = img.shape[0]
img_patches = rearrange(img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', patch_x = self.p, patch_y = self.p)
img_patches = self.project_patches(img_patches)
if self.classification:
img_patches = torch.cat((self.expand_cls_to_batch(batch_size), img_patches), dim = 1)
patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
y = self.transformer(patch_embeddings, mask)
if self.classification:
return self.mlp_head(y[:, 0, :])
else:
return y
总结
ViT的关键在于将图像分类问题转化为序列问题,使用图像块作为标记,并通过Transformer进行处理。这种方法听起来简单,但需要大量的数据支持。遗憾的是,Google拥有预训练数据集,导致结果难以重现,并且即使能够重现,也需要足够的计算能力。ViT为图像分类提供了一种新的视角和方法,随着研究的深入,相信在未来会有更多基于此的改进和创新。