一、前言
Vision Transformer (ViT) 是一种革命性的模型,它将传统的Transformer结构应用于图像处理领域。与传统的卷积神经网络(CNN)相比,ViT通过将图像切割成一系列小块并将其视为序列数据来处理,使用Transformer的强大自注意力机制直接捕捉全局依赖关系。这种方法不仅提高了处理效率和效果,也为理解图像数据提供了全新的视角。本博客将详细解释ViT的关键技术细节和其数学表示,并提供了代码实现以及源码链接,使读者能够更深入地探索和实践这一先进技术。
二、ViT算法的原理
1. 图像分块和嵌入
Vision Transformer (ViT) 中的图像分块和嵌入详解
在Vision Transformer(ViT)中,“图像分块和嵌入”是处理输入图像的首要步骤,对整体模型性能具有决定性影响。此过程涵盖两个主要步骤:将图像切分为小块和将这些块转换为模型可处理的嵌入向量。以下详细解释这两个步骤:
- 图像分块(Patching)
- 目的:将图像转换为序列形式以适应原本处理序列数据的Transformer结构,类似于文本处理中的单词序列化。
- 操作:输入图像尺寸为 H × W × C H \times W \times C H×W×C(其中 H H H为高度, W W W为宽度, C C C为颜色通道数,如RGB的3通道)。图像被均匀切分成多个 P × P P\times P P×P的小块。
- 结果:按此方式分块后,会生成 H P × W P \frac{H}{P} \times \frac{W}{P} PH×PW个图像块,每个块成为输入序列的一个元素。
- 图像块嵌入
- 目的:将原始像素数据的图像块转换为模型可有效处理的格式。
- 操作:每个图像块首先被展平(如果块大小为 P × P P \times P P×P,且有 C C C个通道,则展平后的大小为 P × P × C P \times P \times C P×P×C)。展平后的图像块通过线性层(如全连接层或乘法矩阵 E \mathbf{E} E)映射到一个固定维度的向量。
- 嵌入层:这个线性层可视为将每个展平的图像块从像素空间转换到嵌入空间的矩阵 E \mathbf{E} E。
- 位置编码:由于Transformer结构本身不处理输入数据的顺序信息,每个图像块还需添加位置编码(position embeddings),这些编码在训练中可学习,帮助模型理解各块在原图中的位置。
数学表示
整个嵌入过程可以表示为:
z
0
=
[
x
class
;
x
p
1
E
;
x
p
2
E
;
⋯
;
x
p
N
E
]
+
E
p
o
s
,
E
∈
R
(
P
2
⋅
C
)
×
D
,
E
p
o
s
∈
R
(
N
+
1
)
×
D
\mathbf{z}_0=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_p^1 \mathbf{E} ; \mathbf{x}_p^2 \mathbf{E} ; \cdots ; \mathbf{x}_p^N \mathbf{E}\right]+\mathbf{E}_{p o s}, \quad \mathbf{E} \in \mathbb{R}^{\left(P^2 \cdot C\right) \times D}, \mathbf{E}_{p o s} \in \mathbb{R}^{(N+1) \times D}
z0=[xclass ;xp1E;xp2E;⋯;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D
其中:
- x p i \mathbf{x}_p^i xpi表示第 i i i个图像块的展平向量。
- E \mathbf{E} E是转换矩阵,将图像块从像素空间映射到嵌入空间。
- x class \mathbf{x}_\text{class} xclass是一个额外的“分类”嵌入,用于最终的图像分类任务。
- E pos \mathbf{E}_{\text{pos}} Epos是加到所有图像块嵌入上的位置编码向量。
这种方法将图像处理问题转化为序列处理问题,使得为文本设计的Transformer架构可以有效地应用于图像分析任务。
2. Transformer编码器
(1) 多头自注意力机制(Multi-Head Self-Attention)
功能
自注意力机制使模型在处理序列的每个元素时考虑到序列中的所有其他元素,从而理解数据中的复杂依赖关系。在多头自注意力中,此机制被分成多个“头”,每个头从不同角度学习输入数据的表示,增强模型表达能力。
公式
多头自注意力可以表达为:
MultiHead
(
Q
,
K
,
V
)
=
Concat
(
head
1
,
…
,
head
h
)
W
O
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
MultiHead(Q,K,V)=Concat(head1,…,headh)WO
其中每个头的计算为:
head
i
=
Attention
(
Q
W
i
Q
,
K
W
i
K
,
V
W
i
V
)
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
headi=Attention(QWiQ,KWiK,VWiV)
自注意力计算公式为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dkQKT)V
这里,
Q
,
K
,
V
Q, K, V
Q,K,V是查询(Query)、键(Key)、值(Value),
W
i
Q
,
W
i
K
,
W
i
V
,
W
O
W_i^Q, W_i^K, W_i^V, W^O
WiQ,WiK,WiV,WO是可学习的参数矩阵,
d
k
d_k
dk是键向量的维数。
(2) 残差连接和层归一化(Layer Normalization)
功能
每个自注意力和前馈网络的输出通过残差连接后进行层归一化。残差连接帮助防止深层网络训练中的梯度消失问题,而层归一化则使训练更稳定。
公式
残差连接和层归一化的操作可表示为:
output
=
LayerNorm
(
x
+
Sublayer
(
x
)
)
\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))
output=LayerNorm(x+Sublayer(x))
其中
Sublayer
(
x
)
\text{Sublayer}(x)
Sublayer(x)是子层如自注意力或前馈网络的输出。
(3) 前馈网络(Position-wise Feed-Forward Networks)
功能
每个编码器层包含一个前馈网络,该网络是逐位置操作的,对序列的每个位置应用相同的全连接层,增加模型的非线性表示能力。
公式
前馈网络通常包括两个线性变换,中间插入ReLU激活函数:
FFN
(
x
)
=
max
(
0
,
x
W
1
+
b
1
)
W
2
+
b
2
\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2
FFN(x)=max(0,xW1+b1)W2+b2
其中
W
1
,
W
2
,
b
1
,
b
2
W_1, W_2, b_1, b_2
W1,W2,b1,b2是网络参数。
通过这些组件的协同作用,Transformer编码器有效处理序列数据,捕捉长距离依赖,并通过自注意力机制自适应地关注输入序列中的重要部分。每个编码器层的输出都作为下一层的输入,共同贡献于最终的序列表示。
3. 多层Transformer编码器层
在探讨Vision Transformer的架构时,我们不可忽视的一个重要部分是数据如何在模型中逐层传递和处理。这个过程是通过一系列编码器层实现的,每一层都对数据进行转换和细化,以提取更深层次的特征。
下面是根据文献的描述,用比较简洁的公式来描述Transformer编码器层中的每一层都做了什么。
- 自注意力和残差连接
每层首先进行自注意力操作,以关注输入中的关键部分:
z ℓ ′ = MSA ( LN ( z ℓ − 1 ) ) + z ℓ − 1 , ℓ = 1 … L \mathbf{z}_{\ell}^{\prime} = \operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right) + \mathbf{z}_{\ell-1}, \quad \ell=1 \ldots L zℓ′=MSA(LN(zℓ−1))+zℓ−1,ℓ=1…L
- 前馈网络和残差连接
自注意力输出进一步通过前馈网络处理:
z ℓ = MLP ( LN ( z ℓ ′ ) ) + z ℓ ′ , ℓ = 1 … L \mathbf{z}_{\ell} = \operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right) + \mathbf{z}_{\ell}^{\prime}, \quad \ell=1 \ldots L zℓ=MLP(LN(zℓ′))+zℓ′,ℓ=1…L
- 输出处理
最后一层的输出用于下游任务,如分类:
y = LN ( z L 0 ) \mathbf{y} = \operatorname{LN}\left(\mathbf{z}_L^0\right) y=LN(zL0)
4. 分类头
在Vision Transformer (ViT)和其他基于Transformer的模型中,分类头(Classification Head)是模型架构的关键部分,负责将编码器层的输出转换为最终的分类预测。Transformer的输出通过分类头进行处理以预测图像类别:
output
=
softmax
(
z
L
0
W
class
)
,
\text{output} = \text{softmax}(\mathbf{z}_L^0 W_\text{class}),
output=softmax(zL0Wclass),
这里
z
L
0
\mathbf{z}_L^0
zL0是经过所有Transformer层处理后的CLS标记的向量,
W
class
W_\text{class}
Wclass是分类头的权重矩阵。
分类头作为连接编码器输出与最终类别预测的桥梁,在Vision Transformer中通过线性变换和softmax激活函数,将编码的图像特征有效转换为类别概率,完成分类任务。
三、掩码在Transformer模型中的应用
在Transformer模型中,掩码(mask)扮演着至关重要的角色,尤其是在处理序列数据时,如文本或时间序列。掩码主要用于两个目的:屏蔽无效的输入或填充值以及防止模型在生成输出时提前“窥视”未来的信息。
1. 填充掩码
在处理不等长的序列数据时,通常需要在较短的序列后添加填充(padding),以使它们与批次中最长的序列等长。这些填充值是人为添加的,不应影响模型的学习和预测。因此,我们使用填充掩码来指示模型在计算自注意力时忽略这些填充位置。
填充掩码矩阵实例:
假设一个序列长度为5,其中最后两个元素是填充的:
[
1
,
1
,
1
,
0
,
0
]
[1,1,1,0,0]
[1,1,1,0,0]
2. 未来信息掩码
在自然语言处理的文本生成任务中,如机器翻译或文本摘要,保证模型在预测当前词或字符时不能使用到未来的信息至关重要。为此,Transformer的解码器使用所谓的未来信息掩码,这通常是一个下三角形矩阵。这种掩码确保在计算每个输出时,只能访问到当前和之前的输入,而未来的输入则被屏蔽。这种技术是确保序列生成任务中信息正确流动的关键。
未来信息掩码矩阵实例:
[
θ
−
∞
−
∞
−
∞
−
∞
θ
θ
−
∞
−
∞
−
∞
θ
θ
θ
−
∞
−
∞
θ
θ
θ
θ
−
∞
θ
θ
θ
θ
θ
]
\begin{array}{cccc} \left[\begin{array}{ccccc} \theta & -\infty & -\infty & -\infty & -\infty \\ \theta & \theta & -\infty & -\infty & -\infty \\ \theta & \theta & \theta & -\infty & -\infty \\ \theta & \theta & \theta & \theta & -\infty \\ \theta & \theta & \theta & \theta & \theta \\ \end{array}\right] \end{array}
θθθθθ−∞θθθθ−∞−∞θθθ−∞−∞−∞θθ−∞−∞−∞−∞θ
在实践中,这通过将填充位置的注意力得分设置为负无穷(-inf
),进而在应用softmax时将这些位置的权重变为零来实现。
3. 组合使用掩码
在解码器中,填充掩码和未来信息掩码往往需要组合使用。这种组合掩码既屏蔽了填充值也避免了信息泄露,从而使模型能够更准确地学习如何基于前文生成文本。
组合掩码矩阵实例:
[
θ
−
∞
−
∞
−
∞
−
∞
θ
θ
−
∞
−
∞
−
∞
θ
θ
θ
−
∞
−
∞
θ
θ
θ
−
∞
−
∞
θ
θ
θ
−
∞
−
∞
]
\begin{array}{cccc} \left[\begin{array}{ccccc} \theta & -\infty & -\infty & -\infty & -\infty \\ \theta & \theta & -\infty & -\infty & -\infty \\ \theta & \theta & \theta & -\infty & -\infty \\ \theta & \theta & \theta & -\infty & -\infty \\ \theta & \theta & \theta & -\infty & -\infty \\ \end{array}\right] \end{array}
θθθθθ−∞θθθθ−∞−∞θθθ−∞−∞−∞−∞−∞−∞−∞−∞−∞−∞
掩码不仅提高了模型的效率(通过忽略不必要的计算),还增强了模型的安全性和可靠性,使其在预测时只依赖于合适的上下文信息。理解和正确实现掩码是优化Transformer模型性能的关键。
四、ViT算法的代码
1. 简单实现
提供一个最简单的Vision Transformer (ViT) 模型的PyTorch实现涉及构建模型的基本组件,如图像分块、位置编码、Transformer编码器层和分类头。以下是一个简化的版本,展示了如何在PyTorch中实现基本的ViT结构:
import torch
from torch import nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim):
"""
初始化PatchEmbedding模块。
参数:
img_size (int): 输入图像的边长(假设图像是正方形)。
patch_size (int): 每个图像块的边长。
in_channels (int): 输入图像的通道数。
embed_dim (int): 嵌入向量的维度。
"""
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
"""
前向传播函数,从图像中提取嵌入的图像块。
参数:
x (Tensor): 形状为 [B, C, H, W] 的输入张量。
返回:
Tensor: 嵌入后的图像块,形状为 [B, N, E]。
"""
x = self.proj(x) # 使用卷积提取块并投影到嵌入维度
x = x.flatten(2) # 将高度和宽度维度合并
x = x.transpose(1, 2) # 转换维度为 [B, N, E]
return x
class TransformerEncoder(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
"""
初始化Transformer编码器层。
参数:
embed_dim (int): 嵌入维度。
num_heads (int): 注意力机制中的头数。
ff_dim (int): 前馈网络中间层的维度。
dropout (float): Dropout概率。
"""
super().__init__()
self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.ffn = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
nn.ReLU(),
nn.Linear(ff_dim, embed_dim),
)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
前向传播函数,处理输入张量通过编码器层。
参数:
x (Tensor): 形状为 [B, N, E] 的输入张量。
返回:
Tensor: 形状为 [B, N, E] 的输出张量。
"""
# 自注意力和层归一化
att, _ = self.attention(x, x, x)
x = x + self.dropout(att)
x = self.norm1(x)
# 前馈网络和层归一化
ffn = self.ffn(x)
x = x + self.dropout(ffn)
x = self.norm2(x)
return x
class ViT(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim, num_heads, ff_dim, num_layers, num_classes):
"""
初始化Vision Transformer (ViT) 模型。
参数:
img_size (int): 输入图像的边长(假设图像是正方形)。
patch_size (int): 每个图像块的边长。
in_channels (int): 输入图像的通道数。
embed_dim (int): 嵌入维度。
num_heads (int): 注意力机制中的头数。
ff_dim (int): 前馈网络中间层的维度。
num_layers (int): Transformer编码器层数。
num_classes (int): 分类任务的类别数。
"""
super().__init__()
self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 分类令牌
self.positional_embedding = nn.Parameter(torch.randn(1, 1 + self.patch_embedding.n_patches, embed_dim)) # 位置嵌入
self.transformer = nn.ModuleList([
TransformerEncoder(embed_dim, num_heads,ff_dim) for _ in range(num_layers) # 构建多层Transformer编码器
])
self.mlp_head = nn.Sequential(
nn.LayerNorm(embed_dim), # 层归一化
nn.Linear(embed_dim, num_classes) # 最后的分类层
)
def forward(self, x):
"""
前向传播函数,处理输入图像通过ViT模型。
参数:
x (Tensor): 形状为 [B, C, H, W] 的输入图像张量。
返回:
Tensor: 最终的分类输出,形状为 [B, num_classes]。
"""
x = self.patch_embedding(x) # 转换图像到一个序列的嵌入表示
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # 复制分类标记到每个样本
x = torch.cat((cls_tokens, x), dim=1) # 将分类标记和嵌入的图像块合并
x += self.positional_embedding # 添加位置嵌入
for encoder in self.transformer:
x = encoder(x) # 逐层传递通过Transformer编码器
cls_token_final = x[:, 0] # 提取分类标记对应的输出
return self.mlp_head(cls_token_final) # 通过MLP获取最终的分类结果
# 示例用法
model = ViT(img_size=256, patch_size=16, in_channels=3, embed_dim=768, num_heads=12, ff_dim=3072, num_layers=12, num_classes=1000)
img = torch.randn(2, 3, 256, 256) # 随机生成模拟输入
logits = model(img) # 调用模型进行前向传播
说明:
- PatchEmbedding 类将图像转换为一个序列化的嵌入向量。
- TransformerEncoder 类实现了标准的Transformer编码器层,包括多头自注意力和前馈网络。
- ViT 类组装了整个Vision Transformer模型的结构,包括添加分类标记(CLS token)和位置编码,然后通过多层Transformer编码器处理,最后通过一个线性层进行分类。
2. Attention的手动实现
上面的Attention部分是通过调用pytorch实现的nn.MultiheadAttention
。这里我们手动实现一个Attention。
import torch
import torch.nn as nn
import torch.nn.functional as F
def generate_square_subsequent_mask(size):
"""生成一个下三角掩码,防止位置i获取到位置j>i的信息。
参数:
size: 序列的长度,即掩码的尺寸将为 size x size。
返回:
mask: 下三角形矩阵,对角线及以下的元素为1,对角线以上的元素为0。
"""
mask = torch.tril(torch.ones(size, size))
return mask
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
"""
计算缩放点积自注意力。
参数:
query: 查询张量,形状为 [batch_size, num_heads, seq_length, dim_per_head]
key: 键张量,形状为 [batch_size, num_heads, seq_length, dim_per_head]
value: 值张量,形状为 [batch_size, num_heads, seq_length, dim_per_head]
mask: 可选的掩码张量,用于屏蔽某些位置。
dropout: 可选的dropout模块,用于应用到注意力分数上。
返回:
加权值张量和注意力张量,形状均为 [batch_size, num_heads, seq_length, dim_per_head]
"""
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
class MultiheadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
"""
多头自注意力模块的初始化。
参数:
embed_dim: 嵌入的维度。
num_heads: 头的数量。
dropout: Dropout比率。
"""
super().__init__()
self.num_heads = num_heads
self.dim_per_head = embed_dim // num_heads
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, mask=None):
"""
前向传播方法。
参数:
query: 查询张量。
key: 键张量。
value: 值张量。
mask: 掩码张量。
返回:
注意力机制处理后的输出张量。
"""
batch_size = query.size(0)
# 线性投影
query = self.query(query).view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
key = self.key(key).view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
value = self.value(value).view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
# 应用自注意力机制
x, attn = scaled_dot_product_attention(query, key, value, mask, self.dropout)
# "Concat"使用view并应用最终的线性层
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.dim_per_head)
return self.out_proj(x)
# 示例用法
# 定义嵌入维度、头的数量和dropout比率
embed_dim = 512 # 嵌入的维度,即每个词向量的维度
num_heads = 8 # 多头注意力中头的数量
dropout = 0.1 # 在注意力权重上应用的dropout比率
# 创建多头注意力模块的实例
attn = MultiheadAttention(embed_dim, num_heads, dropout)
# 定义序列长度
seq_length = 10 # 序列长度,即每个输入序列中元素的数量
# 生成下三角掩码,防止在生成当前输出时看到未来的信息
mask = generate_square_subsequent_mask(seq_length)
# 生成随机的查询、键、值数据
query = torch.rand(5, seq_length, embed_dim) # 生成随机查询数据,形状为 [batch_size, seq_length, embed_dim]
key = value = query # 在自注意力中,键和值设为与查询相同,这是自注意力的典型设置
# 使用多头注意力模块处理数据,传入掩码以防止看到未来信息
output = attn(query, key, value, mask=mask)
3. Vision Transformer (ViT) 源代码资源
Vision Transformer(ViT)是一种创新的神经网络架构,它将Transformer模型应用于视觉任务。如果你想探索或使用ViT的源代码,以下是一些可以找到实现的优秀资源:
官方实现
Google Research 团队提供了 Vision Transformer 的官方实现,这是理解和学习这种模型的极佳起点。
这个GitHub仓库包含了实现ViT所需的全部代码及预训练模型的链接。
社区实现
社区开发者也提供了多个易于理解和集成的ViT实现,以下是一些受欢迎的选项:
- lucidrains/vit-pytorch: 由社区维护的PyTorch实现,特别适合那些希望快速集成ViT到自己项目的开发者。
- rwightman/pytorch-image-models(也称为 timm 库): 包含多种图像模型的PyTorch实现库,包括多种Vision Transformer模型及其变体。
五、总结
Vision Transformer (ViT) 在图像处理领域代表了一项重大的技术突破。通过把图像分块并作为序列处理,ViT利用了Transformer架构的强大能力,以独特的方式捕捉了图像的全局信息并进行了有效分类。本博客不仅介绍了ViT的算法原理,还提供了其代码实现和源码访问链接,这不仅展示了自注意力机制在非NLP领域的巨大潜力,也为读者提供了实际应用这些技术的途径。随着技术的不断发展,ViT及其变体有望推动未来图像处理技术的进一步创新。
参考文献
AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE