前言
VIT模型即vision transformer,其想法是将在NLP领域的基于自注意力机制transformer模型用于图像任务中,相比于图像任务中的传统的基于卷积神经网络模型,VIT模型在大数据集上有着比卷积网络更强的效果和更节约的成本。
transformer
transformer模型是用于自然语言处理的一个基于注意力机制的模型,其图如下所示,该模型主要由解码器和编码器两部分组成。在nlp相关任务中,处理的数据对象主要是句子或句子对,因此,在训练之前,存在一个由多个token组成的字典。而输入模型的数据为形如大小为NF的向量,其中N为tokens的数量,F为表示每个token语义信息的向量长度,然后通过线性变化,加入位置信息得到大小为ND的向量,其中D为论文规定的输入给注意力层的向量大小。
注意力机制
从人认识句子的模式考虑,面对一个句子的多个单词时,我们对于不同的词组的关注度自然也存在不同,基于这个想法提出的注意力机制的抽象模型如下:query表示待处理目标,key-value表示键值对,输出attention本质上为values的加权和,而这里的权重即为注意力系数,其计算公式如下:其中,Q K分别表示query和健值的向量矩阵,dk为二者的大小。另外,还可将QK分为多个子矩阵的拼接,分别计算注意力,最后将结果拼接回去原来的大小。这种方法称为多头注意力机制
从transformer到VIT
若想将transformer模型用于对于二维图像的处理,首先需要解决的问题即是如和将二维图像转化为可输入给transformer模型的1维向量,自然想到将大小维NN的图像分为pp的小图像patch,再将每个patch展开这样得到大小维度为(NN/PP)(PP*3)的向量,3表示rgb三通道,再将该向量经过一个线性变化使其特征维度变为D,即可继续输入给transformer进行训练。其模型结构图如下:
如图所示,在做图像分类任务时,需要增加一个表示类别的,token,最后在加上位置编码信息,得到的向量作为最终的transformer的输入。另外,在vit模型中,QKV都是同样来自图像patch的同样大小的三个向量。
整个流程用公式表示如下所示其中第一步即图像的预处理,包扩图像分块,增加类别信息,位置信息,E表示将表示图像信息的向量通过线性变化进行维度转化;第二个式子为MSA部分,包括多头自注意力、跳跃连接 (Add) 和层规范化 (Norm) 三个部分,可以重复L个MSA block;第三个式子为MLP部分,包括前馈网络 (FFN)、跳跃连接 (Add) 和层规范化 (Norm) 三个部分。
代码实现
pytorch中可直接调用搭建vit模型,相关代码如下所示:
```python
import torch
import numpy as np
from vit_pytorch import ViT
import torchvision
from torchsummary import summary
#创建VIt模型实例
v=ViT(
image_size=256, #原始图像大小 256*256
patch_size=32, #图像块的大小,即将原始图像按块大小切割
num_classes=10, #分类数量
dim=1024, #transformer隐藏变量维度。即输入给transform模型的特征维度
depth=6, #transform编码器层数
heads=6, #msa中多头注意力机制的头数
mlp_dim=2048,
dropout=0.1,
emb_dropout=0.1)
如输入一个图像,能得到一个分类结果。
```python
img=torch.randn(1,3,256,256) #batch_size*C*h*w
preds=v(img)
preds.size()
##从头搭建vit模型
通过上面的模型原理介绍,VIT模型其实是以transformer为基础的,因此需要先搭建ffn,注意力机制等组件,再将其与图像预处理,编码嵌入层等拼接起来得到一个完整的vit模型
import torch
from torch import nn , einsum
import torch.nn.functional as F
from einops import rearrange , repeat #einops是一个处理张量的第三方库 output_tensor = rearrange(input_tensor, 't b c -> b c t')
from einops.layers.torch import Rearrange # 沿着某一维复制 output_tensor = repeat(input_tensor, 'h w -> h w c', c=3)
# Rearrange('b c h w -> b (c h w)'),
def pair(t):