Vision Transformer (ViT) 计算复杂度详解
1. 输入与 Patch Embedding 的复杂度
1.1 输入图像
假设输入图像的尺寸为:
输入图像 = H × W × C \text{输入图像} = H \times W \times C 输入图像=H×W×C
- H H H 和 W W W:图像的高度和宽度。
- C C C:图像的通道数(通常为 3,代表 RGB 图像)。
1.2 Patch 分割
将图像分割成 P × P P \times P P×P 的 Patch,每个 Patch 被展平为一维向量,作为 Transformer 的输入。
-
Patch 的数量:
N = H ⋅ W P 2 N = \frac{H \cdot W}{P^2} N=P2H⋅W
N N N 是分割得到的 Patch 数。 -
每个 Patch 的维度:
d patch = P ⋅ P ⋅ C d_{\text{patch}} = P \cdot P \cdot C dpatch=P⋅P⋅C
即每个 Patch 包含 P × P P \times P P×P 个像素,乘以通道数 C C C。
1.3 Patch Embedding
每个 Patch 经过一个线性层投影到固定维度 d model d_{\text{model}} dmodel:
-
线性变换的权重矩阵大小:
W embed ∈ R d patch × d model W_{\text{embed}} \in \mathbb{R}^{d_{\text{patch}} \times d_{\text{model}}} Wembed∈Rdpatch×dmodel -
线性变换的计算复杂度:
对所有 N N N 个 Patch 执行线性变换:
O embedding = N ⋅ d patch ⋅ d model O_{\text{embedding}} = N \cdot d_{\text{patch}} \cdot d_{\text{model}} Oembedding=N⋅dpatch⋅dmodel
2. Transformer Encoder 的复杂度
Transformer Encoder 是 ViT 的核心组件,其复杂度主要来源于 多头自注意力机制 和 前馈网络。
2.1 多头自注意力机制
2.1.1 Query、Key、Value 计算
对于输入特征 X ∈ R N × d model X \in \mathbb{R}^{N \times d_{\text{model}}} X∈RN×dmodel:
-
计算 Query、Key 和 Value:
Q = X W Q , K = X W K , V = X W V Q = XW_Q, \quad K = XW_K, \quad V = XW_V Q=XWQ,K=XWK,V=XWV
其中 W Q , W K , W V ∈ R d model × d model