Vision Transformer详解(附代码)

1 引言

T r a n s f o r m e r \mathrm{Transformer} Transformer N L P \mathrm{NLP} NLP中大获成功, V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer则将 T r a n s f o r m e r \mathrm{Transformer} Transformer模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而 T r a n s f o r m e r \mathrm{Transformer} Transformer中的注意力机制可以综合考量全局的特征信息。 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer尽力做到在不改变 T r a n s f o r m e r \mathrm{Transformer} Transformer E n c o d e r \mathrm{Encoder} Encoder架构的前提下,直接将其从 N L P \mathrm{NLP} NLP领域迁移到计算机视觉领域中,目的是让原始的 T r a n s f o r m e r \mathrm{Transformer} Transformer模型开箱即用。如果想要了解 T r a n s f o r m e r \mathrm{Transformer} Transformer原理详细的介绍可以看我的上一篇文章《Transformer详解(附代码)》

2 注意力机制应用

在正式详细介绍 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer之前,先介绍两个注意力机制在计算机视觉中应用的例子。 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中 S A G A N \mathrm{SAGAN} SAGAN A t t n G A N \mathrm{AttnGAN} AttnGAN就早已经在 G A N \mathrm{GAN} GAN的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

S A G A N \mathrm{SAGAN} SAGAN G A N \mathrm{GAN} GAN的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。 S A G A N \mathrm{SAGAN} SAGAN中自注意力机制的操作原理如下图所示。
给定一个 3 3 3通道的输入特征图 X = ( X 1 , X 2 , X 3 ) ∈ R 3 × 3 × 3 X=(X^1,X^2,X^3)\in \mathbb{R}^{3\times 3\times 3} X=(X1,X2,X3)R3×3×3,其中 X i ∈ R 3 × 3 X^{i}\in \mathbb{R}^{3\times 3} XiR3×3 i ∈ { 1 , 2 , 3 } i\in\{1,2,3\} i{1,2,3}。将 X X X分别输入到三个不同的 1 × 1 1\times 1 1×1的卷积层中,并生成 q u e r y \mathrm{query} query特征图 Q ∈ R 3 × 3 × 3 Q\in \mathbb{R}^{3\times 3\times 3} QR3×3×3 k e y \mathrm{key} key特征图 K ∈ R 3 × 3 × 3 K\in \mathbb{R}^{3\times 3\times 3} KR3×3×3 v a l u e \mathrm{value} value特征图 V ∈ R 3 × 3 × 3 V\in \mathbb{R}^{3\times 3\times 3} VR3×3×3。生成 Q Q Q具体的计算过程为,给定三个卷积核 W q 1 W^{q1} Wq1 W q 2 W^{q2} Wq2 W q 3 ∈ R 1 × 1 × 3 W^{q3}\in\mathbb{R}^{1\times1\times3} Wq3R1×1×3,并用这三个卷积核分别与 X X X做卷积运算得到 Q 1 Q^1 Q1 Q 2 Q^2 Q2 Q 3 ∈ R 3 × 3 Q^3\in \mathbb{R}^{3 \times 3} Q3R3×3,即 { Q 1 = X ∗ W q 1 Q 2 = X ∗ W q 2 Q 3 = X ∗ W q 3 \left\{\begin{aligned}Q^1&=X * W^{q1}\\Q^2&=X * W^{q2}\\Q^3&=X*W^{q3}\end{aligned}\right. Q1Q2Q3=XWq1=XWq2=XWq3其中 ∗ * 表示卷积运算符号。同理生成 K K K V V V的计算过程与 Q Q Q的计算过程类似。然后再利用 Q Q Q K K K进行注意力分数的计算得到矩阵 A ∈ R 3 × 3 A\in \mathbb{R}^{3 \times 3} AR3×3,其中矩阵 A A A的元素 a m l a_{ml} aml的计算公式为 a m l = Q m ∗ K l , m ∈ { 1 , 2 , 3 } , l ∈ { 1 , 2 , 3 } a_{ml}=Q^m * K^l,\quad m \in \{1,2,3\},l\in \{1,2,3\} aml=QmKl,m{1,2,3},l{1,2,3}再对矩阵 A A A利用 s o f t m a x \mathrm{softmax} softmax函数进行注意力分布的计算得到注意力分布矩阵 S ∈ R 3 × 3 S\in \mathbb{R}^{3\times 3} SR3×3,其中矩阵 S S S的元素 s m l s_{ml} sml的计算公式为 s m l = exp ⁡ ( a m l ) ∑ i = j 3 exp ⁡ ( a m j ) , m ∈ { 1 , 2 , 3 } , l ∈ { 1 , 2 , 3 } s_{ml}=\frac{\exp(a_{ml})}{\sum\limits_{i=j}^{3}\exp(a_{mj})},\quad m \in \{1,2,3\},l\in\{1,2,3\} sml=i=j3exp(amj)exp(aml),m{1,2,3},l{1,2,3}最后利用注意力分布矩阵 S S S v a l u e \mathrm{value} value特征图 V V V得到最后的输出 O = ( O 1 , O 2 , O 3 ) ∈ R 3 × 3 × 3 O=(O^1,O^2,O^3)\in \mathbb{R}^{3\times 3\times 3} O=(O1,O2,O3)R3×3×3,即 { O 1 = s 11 ⋅ V 1 + s 12 ⋅ V 2 + s 13 ⋅ V 3 O 2 = s 21 ⋅ V 1 + s 22 ⋅ V 2 + s 23 ⋅ V 3 O 3 = s 31 ⋅ V 1 + s 32 ⋅ V 2 + s 33 ⋅ V 3 \left\{\begin{aligned}O^1&=s_{11}\cdot V^1+s_{12}\cdot V^2+s_{13}\cdot V^3\\O^2&=s_{21}\cdot V^1+s_{22}\cdot V^2+s_{23}\cdot V^3\\O^3&=s_{31}\cdot V^1+s_{32}\cdot V^2+s_{33}\cdot V^3\end{aligned}\right. O1O2O3=s11V1+s12V2+s13V3=s21V1+s22V2+s23V3=s31V1+s32V2+s33V3

2.2 AttnGAN

A t t n G A N \mathrm{AttnGAN} AttnGAN通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。 A t t n G A N \mathrm{AttnGAN} AttnGAN中注意力机制的操作原理如下图所示。
 给定输入图像特征向量 h = ( h 1 , h 2 , h 3 , h 4 ) ∈ R D ^ × 4 h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4} h=(h1,h2,h3,h4)RD^×4和词特征向量 e = ( e 1 , e 2 , e 3 , e 4 ) e=(e^1,e^2,e^3,e^4) e=(e1,e2,e3,e4),其中 h i ∈ R D ^ × 1 h^i\in \mathbb{R}^{\hat{D}\times 1} hiRD^×1 e i ∈ R D × 1 e^i\in \mathbb{R}^{D\times 1} eiRD×1 i ∈ { 1 , 2 , 3 , 4 } i\in \{1,2,3,4\} i{1,2,3,4}。首先利用矩阵 W W W进行线性变换将词特征空间 R D \mathbb{R}^{D} RD的向量转换成图像特征空间 R D ^ \mathbb{R}^{\hat{D}} RD^的向量,则有 e ^ = W ⋅ e = ( e ^ 1 , e ^ 2 , e ^ 3 , e ^ 4 ) ∈ R D ^ × 4 \hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in \mathbb{R}^{\hat{D}\times 4} e^=We=(e^1,e^2,e^3,e^4)RD^×4然后再利用转换后的词特征 e ^ \hat{e} e^与图像特征 h h h进行注意力分数的计算得到注意力分数矩阵 S S S,其中的分量 s i j s_{ij} sij的计算公式为 s i j = ( h i ) ⊤ ⋅ e ^ j , i ∈ { 1 , 2 , 3 , 4 } , j ∈ { 1 , 2 , 3 , 4 } s_{ij}=(h^i)^{\top}\cdot \hat{e}^j,\quad i\in \{1,2,3,4\},j\in\{1,2,3,4\} sij=(hi)e^j,i{1,2,3,4},j{1,2,3,4} 再对矩阵 S S S利用 s o f t m a x \mathrm{softmax} softmax函数进行注意力分布的计算得到注意力分布矩阵 β ∈ R 4 × 4 \beta\in \mathbb{R}^{4\times 4} βR4×4,其中矩阵 β \beta β的元素 β i j \beta_{ij} βij的计算公式为 β i j = exp ⁡ ( s i j ) ∑ k = 1 3 exp ⁡ ( s i k ) , i ∈ { 1 , 2 , 3 , 4 } , l ∈ { 1 , 2 , 3 , 4 } \beta_{ij}=\frac{\exp(s_{ij})}{\sum\limits_{k=1}^{3}\exp(s_{ik})},\quad i \in \{1,2,3,4\},l\in\{1,2,3,4\} βij=k=13exp(sik)exp(sij),i{1,2,3,4},l{1,2,3,4}最后利用注意力分布矩阵 β \beta β和图像特征 h h h得到最后的输出 o = ( o 1 , o 2 , o 3 , o 4 ) ∈ R D ^ × 4 o=(o^1,o^2,o^3,o^4)\in \mathbb{R}^{\hat{D}\times 4} o=(o1,o2,o3,o4)RD^×4,即 { o 1 = β 11 ⋅ h 1 + β 12 ⋅ h 2 + β 13 ⋅ h 3 + β 14 ⋅ h 4 o 2 = β 21 ⋅ h 1 + β 22 ⋅ h 2 + β 23 ⋅ h 3 + β 24 ⋅ h 4 o 3 = β 31 ⋅ h 1 + β 32 ⋅ h 2 + β 33 ⋅ h 3 + β 34 ⋅ h 4 o 4 = β 41 ⋅ h 1 + β 42 ⋅ h 2 + β 43 ⋅ h 3 + β 44 ⋅ h 4 \left\{\begin{aligned}o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4\end{aligned}\right. o1o2o3o4=β11h1+β12h2+β13h3+β14h4=β21h1+β22h2+β23h3+β24h4=β31h1+β32h2+β33h3+β34h4=β41h1+β42h2+β43h3+β44h4

3 Vision Transformer

本节主要详细介绍 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的工作原理,3.1节是关于 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的整体框架,3.2节是关于 T r a n s f o r m e r   E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder的内部操作细节。对于 T r a n s f o r m e r   E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention的原理本文不会赘述,具体想了解的可以参考上一篇文章《Transformer详解(附代码)》中相关原理的介绍。不难发现,不管是自然语言处理中的 T r a n s f o r m e r \mathrm{Transformer} Transformer,还是计算机视觉中图像生成的 S A G A N \mathrm{SAGAN} SAGAN,以及文本生成图像的 A t t n G A N \mathrm{AttnGAN} AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer整体框架

如果下图所示为 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的整体框架以及相应的训练流程

  • 给定一张图片 X ∈ R 3 n × 3 n X\in \mathbb{R}^{3n\times 3n} XR3n×3n,并将它分割成 9 9 9 p a t c h \mathrm{patch} patch分别为 x 1 , ⋯   , x 9 ∈ R n × n x^1,\cdots,x^9\in\mathbb{R}^{n\times n} x1,,x9Rn×n。然后再将这个 9 9 9 p a t c h \mathrm{patch} patch拉平,则有 x 1 , ⋯   , x 9 ∈ R n 2 x^1,\cdots,x^9\in\mathbb{R}^{n^2} x1,,x9Rn2
  • 利用矩阵 W ∈ R l × n 2 W\in \mathbb{R}^{l \times n^2} WRl×n2将拉平后的向量 x i ∈ R n 2 , i ∈ { 1 , ⋯   , 9 } x^i\in\mathbb{R}^{n^2},i\in\{1,\cdots,9\} xiRn2,i{1,,9}经过线性变换得到图像编码向量 z i ∈ R l , i ∈ { 1 , ⋯   , 9 } z^i\in \mathbb{R}^{l},i\in\{1,\cdots,9\} ziRl,i{1,,9},具体的计算公式为 z i = W ⋅ x i , i ∈ { 1 , ⋯ 9 } z^i = W\cdot x^i,\quad i\in\{1,\cdots9\} zi=Wxi,i{1,9}
  • 然后将图像编码向量 z i , i ∈ { 1 , ⋅ , 9 } z^{i},i\in\{1,\cdot,9\} zi,i{1,,9}和类编码向量 z 0 z^0 z0分别与对应的位置编进行加和得到输入编码向量,则有 z i + p i ∈ R l , i ∈ { 0 , ⋯ 9 } z^{i}+p^{i}\in\mathbb{R}^l,\quad i\in\{0,\cdots 9\} zi+piRl,i{0,9}
  • 接着将输入编码向量输入到 V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder中得到对应的输出 o i ∈ R l , i ∈ { 0 , ⋯   , 9 } o^i\in \mathbb{R}^l,i\in\{0,\cdots,9\} oiRl,i{0,,9}
  • 最后将类编码向量 o 0 o^0 o0输入全连接神经网络中 M L P \mathrm{MLP} MLP得到类别预测向量 y ^ ∈ R c \hat{y}\in\mathbb{R}^c y^Rc,并与真实类别向量 y ∈ R c y\in\mathbb{R}^c yRc计算交叉熵损失得到损失值 l o s s loss loss,利用优化算法更新模型的权重参数

注意事项: 看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量 o 0 o^0 o0 V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder其它的输出为什么没有输入到 M L P \mathrm{MLP} MLP中?为了回答这个问题,我们令函数 f 0 ( ⋅ ) f_0(\cdot) f0() V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder,则类编码向量 o 0 o^{0} o0可以表示为 o 0 = f 0 ( z 0 + p 0 , ⋯   , z 9 + p 9 ) o^0=f_0(z^0+p^0,\cdots,z^9+p^9) o0=f0(z0+p0,,z9+p9)由上公式可以发现,类编码向量 o 0 o^{0} o0是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

3.2 Transformer Encoder操作原理

如下图所示分别为 V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder模型结构图和原始 T r a n s f o r m e r   E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder的模型结构图。可以直观的发现 V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder T r a n s f o r m e r   E n c o d e r \mathrm{Transformer\text{ }Encoder} Transformer Encoder都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的   T r a n s f o r m e r \mathrm{ \text{ }Transformer}  Transformer代码实例中,将以下两种 E n c o d e r \mathrm{Encoder} Encoder网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。
下图左半部分 V i s i o n   T r a n s f o r m e r   E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder具体的操作流程为

  • 给定输入编码矩阵 Z ∈ R l × n Z\in\mathbb{R}^{l\times n} ZRl×n,首先将其进行层归一化得到 Z ′ ∈ R l × n Z^{\prime}\in\mathbb{R}^{l \times n} ZRl×n
  • 利用矩阵 W q , W k , W v ∈ R l × l W^{q},W^{k},W^{v}\in \mathbb{R}^{l\times l} Wq,Wk,WvRl×l Z ′ Z^{\prime} Z进行线性变换得到矩阵 Q , K , W ∈ R l × n Q,K,W\in\mathbb{R}^{l\times n} Q,K,WRl×n具体的计算过程为 { Q = W q ⋅ Z ′ K = W k ⋅ Z ′ V = W v ⋅ Z ′ \left\{\begin{aligned}Q &= W^{q}\cdot Z^{\prime}\\K&=W^{k}\cdot Z^{\prime}\\V&=W^v \cdot Z^{\prime}\end{aligned}\right. QKV=WqZ=WkZ=WvZ再将这三个矩阵输入到 M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention(该原理参考《Transformer详解(附代码)》)中得到矩阵 Z ′ ′ ∈ R l × n Z^{\prime\prime}\in \mathbb{R}^{l \times n} ZRl×n将最原始的输入矩阵 Z Z Z Z ′ ′ Z^{\prime\prime} Z进行残差计算得到 Z + Z ′ ′ ∈ R l × n Z+Z^{\prime\prime}\in \mathbb{R}^{l\times n} Z+ZRl×n
  • Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z进行第二次层归一化得到 Z ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime}\in\mathbb{R}^{l\times n} ZRl×n,然后再将 Z ′ ′ ′ Z^{\prime\prime\prime} Z输入到全连接神经网络中进行线性变换得到 Z ′ ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} ZRl×n。最后将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z Z ′ ′ ′ ′ Z^{\prime\prime\prime\prime} Z进行残差操作得到该 B l o c k \mathrm{Block} Block的输出 Z + Z ′ ′ + Z ′ ′ ′ ′ ∈ R l × n Z+Z^{\prime\prime}+Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z+Z+ZRl×n。一个 E n c o d e r \mathrm{Encoder} Encoder可以将 N N N B l o c k \mathrm{Block} Block进行堆叠,最后得到的输出为 O ∈ R l × n O\in\mathbb{R}^{l\times n} ORl×n

4 程序代码

V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的代码示例如下所示。该代码是由上一篇《Transformer详解(附代码)》的代码的基础上改编而来。 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的作者的本意就是想让在 N L P \mathrm{NLP} NLP中的 T r a n s f o r m e r \mathrm{Transformer} Transformer模型架构做尽可能少的修改可以直接迁移到 C V \mathrm{CV} CV中,所以以下程序尽可能保持作者的原意,并在代码实现了两种 E n c o d e r \mathrm{Encoder} Encoder的网络结构,即3.2节图片所示的两个网络结构,一种是最原始的 E n c o d e r \mathrm{Encoder} Encoder网络结构,一种是 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer论文里的 E n c o d e r \mathrm{Encoder} Encoder的网络结构。这里需要注意的是, V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer里并能没有 D e c o d e r \mathrm{Decoder} Decoder模块,所以不需要计算 E n c o d e r \mathrm{Encoder} Encoder D e c o d e r \mathrm{Decoder} Decoder的交叉注意力分布,这就进一步给 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的编程带来了简便。 V i s i o n   T r a n s f o r m e r \mathrm{Vision\text{ }Transformer} Vision Transformer的开源代码的网址为https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch

import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange

def inputs_deal(inputs):
	return inputs if isinstance(inputs, tuple) else(inputs, inputs)

class SelfAttention(nn.Module):
	def __init__(self, embed_size, heads):
		super(SelfAttention, self).__init__()
		self.embed_size = embed_size
		self.heads = heads
		self.head_dim = embed_size // heads

		assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

		self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

	def forward(self, values, keys, query):
		N =query.shape[0]
		value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

		# split embedding into self.heads pieces
		values = values.reshape(N, value_len, self.heads, self.head_dim)
		keys = keys.reshape(N, key_len, self.heads, self.head_dim)
		queries = query.reshape(N, query_len, self.heads, self.head_dim)
		
		values = self.values(values)
		keys = self.keys(keys)
		queries = self.queries(queries)

		energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
		# queries shape: (N, query_len, heads, heads_dim)
		# keys shape : (N, key_len, heads, heads_dim)
		# energy shape: (N, heads, query_len, key_len)

		attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

		out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
		# attention shape: (N, heads, query_len, key_len)
		# values shape: (N, value_len, heads, heads_dim)
		# (N, query_len, heads, head_dim)

		out = self.fc_out(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, embed_size, heads, dropout, forward_expansion):
		super(TransformerBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm = nn.LayerNorm(embed_size)

		self.feed_forward = nn.Sequential(
			nn.Linear(embed_size, forward_expansion*embed_size),
			nn.ReLU(),
			nn.Linear(forward_expansion*embed_size, embed_size)
		)
		self.dropout = nn.Dropout(dropout)


	def forward(self, value, key, query, x, type_mode):
		if type_mode == 'original':
			attention = self.attention(value, key, query)
			x = self.dropout(self.norm(attention + x))
			forward = self.feed_forward(x)
			out = self.dropout(self.norm(forward + x))
			return out
		else:
			attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
			x =self.dropout(attention + x)
			forward = self.feed_forward(self.norm(x))
			out = self.dropout(forward + x)
			return out

class TransformerEncoder(nn.Module):
	def __init__(
			self,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout = 0,
			type_mode = 'original'
		):
		super(TransformerEncoder, self).__init__()
		self.embed_size = embed_size
		self.type_mode = type_mode
		self.Query_Key_Value = nn.Linear(embed_size, embed_size * 3, bias = False)

		self.layers = nn.ModuleList(
			[
				TransformerBlock(
					embed_size,
					heads,
					dropout=dropout,
					forward_expansion=forward_expansion,
					)
				for _ in range(num_layers)]
		)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x):
		for layer in self.layers:
			QKV_list = self.Query_Key_Value(x).chunk(3, dim = -1)
			x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)
		return x

class VisionTransformer(nn.Module):
	def __init__(self, 
				image_size, 
				patch_size, 
				num_classes, 
				embed_size, 
				num_layers, 
				heads, 
				mlp_dim, 
				pool = 'cls',
				channels = 3,
				dropout = 0,
				emb_dropout = 0.1,
				type_mode = 'vit'):
		super(VisionTransformer, self).__init__()
		img_h, img_w = inputs_deal(image_size)
		patch_h, patch_w = inputs_deal(patch_size)
		
		assert img_h % patch_h == 0 and img_w % patch_w == 0, 'Img dimensions can be divisible by the patch dimensions'

		num_patches = (img_h // patch_h) * (img_w // patch_w)

		patch_size = channels * patch_h * patch_w

		self.patch_embedding = nn.Sequential(
			Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
			nn.Linear(patch_size, embed_size, bias=False)
		)


		self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))
		self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
		self.dropout = nn.Dropout(emb_dropout)



		self.transformer = TransformerEncoder(embed_size, 
									num_layers, 
									heads, 
									mlp_dim,
									dropout)
		self.pool = pool
		self.to_latent = nn.Identity()

		self.mlp_head = nn.Sequential(
			nn.LayerNorm(embed_size),
			nn.Linear(embed_size, num_classes)
		)

	def forward(self, img):
		x = self.patch_embedding(img)
		b, n, _ = x.shape
		cls_tokens = repeat(self.cls_token, '() n d ->b n d', b = b)
		x = torch.cat((cls_tokens, x), dim = 1)
		x += self.pos_embedding[:, :(n + 1)]
		x = self.dropout(x)
		x = self.transformer(x)
		x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
		x = self.to_latent(x)
		return self.mlp_head(x)


if __name__ == '__main__':
	vit = VisionTransformer(
			image_size = 256,
			patch_size = 16,
			num_classes = 10,
			embed_size = 256,
			num_layers = 6,
			heads = 8,
			mlp_dim = 512,
			dropout = 0.1,
			emb_dropout = 0.1
		)
	img = torch.randn(3, 3, 256, 256)
	pred = vit(img)
	print(pred)

以下代码是利用 V i s i o n   T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer网络结构训练一个分类 m n i s t \mathrm{mnist} mnist数据集的主程序代码。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os

        
def train():
    batch_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epoches = 20
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
    mnist_model = VIT.VisionTransformer(
    	image_size = 28,
    	patch_size = 7,
    	num_classes = 10,
    	channels = 1,
    	embed_size = 512,
    	num_layers = 1,
    	heads = 2,
    	mlp_dim =1024,
    	dropout = 0,
    	emb_dropout = 0)
    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
    mnist_model.train()
    for epoch in range(epoches):
    	total_loss = 0 
    	corrects = 0 
    	num = 0
    	for batch_X, batch_Y in train_loader:
    		batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
    		opitimizer.zero_grad()
    		outputs = mnist_model(batch_X)
    		_, pred = torch.max(outputs.data, 1)
    		loss = loss_fn(outputs, batch_Y)
    		loss.backward()
    		opitimizer.step()
    		total_loss += loss.item()
    		corrects = torch.sum(pred == batch_Y.data)
    		num += batch_size
    		print(epoch, total_loss/float(num), corrects.item()/float(batch_size))

if __name__ == '__main__':
	train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个 V i s i o n   T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer模型真的是很烧硬件,跟训练一个普通的 C N N \mathrm{CNN} CNN模型相比,训练一个 V i s i o n   T r a n s f o r m e r \mathrm{Vision \text{ }Transformer} Vision Transformer模型更加耗时耗力。

### 回答1: Vision Transformer(视觉Transformer)是一种基于Transformer架构的深度学习模型,主要用于图像分类和目标检测任务。与传统卷积神经网络不同,Vision Transformer使用了全局自注意力机制,使得模型可以更好地捕捉到不同位置之间的关系。Vision Transformer已经在ImageNet等大规模数据集上取得了优秀的性能表现,并逐渐成为深度学习领域的研究热点。 ### 回答2: Vision Transformer(ViT)是一种全新的视觉识别模型,由谷歌提出,它借鉴了自然语言处理领域中的transformer技术。ViT在图像分类、目标检测和分割等视觉任务中均有较好的效果,并且在一些任务中超越了传统的卷积神经网络(CNN)模型。 ViT模型的核心是transformer encodertransformer decoder两大部分。transformer encoder负责将输入序列转换成特征向量,而transformer decoder则是根据特征向量生成目标输出序列。在ViT模型中,将一张图片视为一个大小为H×W的序列,然后再通过一些处理,将会得到一个大小为N的特征向量,其中每个元素代表了原图中的一个位置坐标。 ViT模型通过将一张图像划分成大小为P × P的图块,然后将每个图块拼接成一个序列,来处理整个图像。基于这样的序列表示方式,ViT将应用transformer架构对序列进行处理,以生成特征表示。同时,由于传统的transformer是为自然语言处理领域设计的,所以需要对其进行一定的调整,才能适用于图像处理任务。因此,ViT引入了一个叫做“patch embedding”的操作,它将每个P × P的图块映射成一个特征向量,然后再进行transformer处理。 除此之外,在ViT模型中还使用了一些其他的技术来提升模型的性能,包括:1)将transformer encoder中的自注意力替换为多头注意力,以增加模型的并行性和泛化性;2)在每个transformer block中应用残差连接和归一化,以加速训练、提高稳定性和精度;3)引入了一个分类头,用于将特征向量转换为最终的输出类别概率。这些技术的应用均使得ViT模型在视觉识别任务上表现出了很好的效果。 总之,ViT模型是一种基于transformer架构的新型视觉识别模型,它采用多头注意力、残差连接等技术,将图像视为序列,利用transformer encodertransformer decoder对序列进行处理,并最终输出目标类别概率。相比于传统的CNN模型,在一些任务中ViT具有更优秀的表现,有望成为未来视觉处理领域的新趋势。 ### 回答3: Vision Transformer(ViT)是谷歌的一款新型视觉模型,与传统的卷积神经网络(CNN)不同,ViT是由注意力机制(Attention Mechanism)组成的纯粹Transformer模型。Transformer源于自然语言处理领域,它能解决文本序列问题,但对于图像数据来说,采用Transformer是一个全新的尝试。 ViT将图像分割成固定数量的块(例如16*16),每个块被视为一个向量。这些向量然后传递给Transformer编码器,其中包括多层自注意力机制。通过学习这些向量之间的相互作用,模型能够提取出各个块之间的关键信息。最后,全连接层通过分类器将最终向量映射到相应的类别。 相较于传统CNN,ViT的明显优势是无需人工设计的特征提取器,这使得模型更具通用性,适用于各种视觉任务,并且能够处理多种分辨率和大小的图像。同时,attention机制带来的优点也让ViT在处理长时间序列信息时表现突出。 然而ViT在使用时还存在一些挑战。由于图像信息需要被分割成固定大小的块,因此对于具有细长结构的对象(如马路、河流等),模型很容易将它们拆分为多个块,导致信息的丢失。此外,由于向量长度的限制,ViT的处理能力存在局限性。 在处理大规模数据时,ViT还需要面对计算资源的挑战。为解决这一问题,研究人员提出了一系列改进算法,如DeiT、T2T-ViT、Swin Transformer等,它们能够更好地处理大规模图像数据。 总的来说,Vision Transformer模型是一种全新的尝试,它使用自注意力机制构建纯Transformer模型来处理图像数据。虽然存在一些性能挑战,但随着技术的不断进步和改进算法的诞生,ViT模型必将成为图像处理领域的重要一员。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值