之前已经了解过一些ViT的东西了,加上这篇文章之前也看过,所以做个粗略的介绍,毕竟在目前很多基于它改的模型上,都基本都有比较详细和精准的介绍。
方法和途径
- 将原始图片: x ∈ R H × W × C \mathbf{x} \in \mathbb{R}^{H \times W \times C} x∈RH×W×C 拆成一串二维的图片块: x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)} xp∈RN×(P2⋅C),其中 ( H , W ) (H, W) (H,W) 是原图片的分辨率, C C C 是通道数量, ( P , P ) (P, P) (P,P) 是每个图片块的分辨率, N = H W / P 2 N = HW/P^2 N=HW/P2 是所有图片块的数量,同时也是Transformer的输入序列的长度。
- Transformer中使用的是一个长度为 D D D的固定向量,因此我们用一个可训练的线性映射,将这些图片块映射到 D D D维去。
- 和BERT的 [class] token 类似,我们在嵌入好的图片块之前加上一个可学习的参数 ( z 0 0 = x class ) \left(\mathbf{z}_{0}^{0}=\mathbf{x}_{\text {class }}\right) (z00=xclass ),用它在 Transformer encoder 的输出 ( z L 0 ) (z_L^0) (zL0) 当作图片的代表 y \mathbf{y} y. 在训练和测试过程中,都在 ( z L 0 ) (z_L^0) (zL0) 上加一个分类头,分类头在预训练时是由一个含有一层隐藏层的 MLP 实现的,在精调时是用一个单独线性层的MLP实现的。
- 位置嵌入也使用了
- LayerNorm在每个block之前,残差连接在每个block后
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E pos , E ∈ R ( P 2 ⋅ C ) × D , E pos ∈ R ( N + 1 ) × D z ℓ ′ = MSA ( L N ( z ℓ − 1 ) ) + z ℓ − 1 , ℓ = 1 … L z ℓ = M L P ( LN ( z ℓ ′ ) ) + z ℓ ′ , ℓ = 1 … L y = L N ( z L 0 ) \begin{aligned} \mathbf{z}_{0} &=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D} \\ \mathbf{z}_{\ell}^{\prime} &=\operatorname{MSA}\left(\mathrm{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & \ell=1 \ldots L \\ \mathbf{z}_{\ell} &=\mathrm{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & \ell=1 \ldots L \\ \mathbf{y} &=\mathrm{LN}\left(\mathbf{z}_{L}^{0}\right) & \end{aligned} z0zℓ′zℓy=[xclass ;xp1E;xp2E;⋯;xpNE]+Epos ,=MSA(LN(zℓ−1))+zℓ−1,=MLP(LN(zℓ′))+zℓ′,=LN(zL0)E∈R(P2⋅C)×D,Epos ∈R(N+1)×Dℓ=1…Lℓ=1…L
主要架构基本就是这样了,没有太多可细说的。不过最近在看它的代码,下面这篇写得很好,可以 code from scratch,有兴趣的可以一看:https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632