CSDN社区:
Vision Transformer(ViT)简介
Transformer最初提出是针对NLP领域的,并且在NLP领域大获成功。受到其启发,尝试将Transformer应用到CV领域。通过实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。
模型结构
ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:
简单而言,模型由三个模块组成:
Linear Projection of Flattened Patches(Embedding层)
Transformer Encoder(图右侧有给出更加详细的结构)
MLP Head(最终用于分类的层结构)
整体流程图如下所示:
ViT模型的输入
传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。
在ViT模型中:
-
通过将输入图像在每个channel上划分为16*16个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。
-
再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。
输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。
-
class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。
-
增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。
-
pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。
-
由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。
实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。
总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。
Transformer Encoder详解
Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由以下几部分组成:
Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理.
Multi-Head Attention.
Dropout/DropPath
MLP Block,就是全连接+GELU激活函数+Dropout. 组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。是Transformer的核心。
多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。
总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。
所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。
实际训练结果如下:
从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。
模型推理
在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。
本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。
在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
transforms.Decode(),
transforms.Resize([224, 224]),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
input_columns=["image"],
num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
class Color(Enum):
"""dedine enum color."""
red = (0, 0, 255)
green = (0, 255, 0)
blue = (255, 0, 0)
cyan = (255, 255, 0)
yellow = (0, 255, 255)
magenta = (255, 0, 255)
white = (255, 255, 255)
black = (0, 0, 0)
def check_file_exist(file_name: str):
"""check_file_exist."""
if not os.path.isfile(file_name):
raise FileNotFoundError(f"File `{file_name}` does not exist.")
def color_val(color):
"""color_val."""
if isinstance(color, str):
return Color[color].value
if isinstance(color, Color):
return color.value
if isinstance(color, tuple):
assert len(color) == 3
for channel in color:
assert 0 <= channel <= 255
return color
if isinstance(color, int):
assert 0 <= color <= 255
return color, color, color
if isinstance(color, np.ndarray):
assert color.ndim == 1 and color.size == 3
assert np.all((color >= 0) & (color <= 255))
color = color.astype(np.uint8)
return tuple(color)
raise TypeError(f'Invalid type for color: {type(color)}')
def imread(image, mode=None):
"""imread."""
if isinstance(image, pathlib.Path):
image = str(image)
if isinstance(image, np.ndarray):
pass
elif isinstance(image, str):
check_file_exist(image)
image = Image.open(image)
if mode:
image = np.array(image.convert(mode))
else:
raise TypeError("Image must be a `ndarray`, `str` or Path object.")
return image
def imwrite(image, image_path, auto_mkdir=True):
"""imwrite."""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(image_path))
if dir_name != '':
dir_name = os.path.expanduser(dir_name)
os.makedirs(dir_name, mode=777, exist_ok=True)
image = Image.fromarray(image)
image.save(image_path)
def imshow(img, win_name='', wait_time=0):
"""imshow"""
cv2.imshow(win_name, imread(img))
if wait_time == 0: # prevent from hanging if windows was closed
while True:
ret = cv2.waitKey(1)
closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
# if user closed window or if some key pressed
if closed or ret != -1:
break
else:
ret = cv2.waitKey(wait_time)
def show_result(img: str,
result: Dict[int, float],
text_color: str = 'green',
font_scale: float = 0.5,
row_width: int = 20,
show: bool = False,
win_name: str = '',
wait_time: int = 0,
out_file: Optional[str] = None) -> None:
"""Mark the prediction results on the picture."""
img = imread(img, mode="RGB")
img = img.copy()
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width
if out_file:
show = False
imwrite(img, out_file)
if show:
imshow(img, win_name, wait_time)
def index2label():
"""Dictionary output for image numbers and categories of the ImageNet dataset."""
metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
meta = io.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
_, wnids, classes = list(zip(*meta))[:3]
clssname = [tuple(clss.split(', ')) for clss in classes]
wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
mapping = {}
for index, (_, class_name) in enumerate(wind2class_name):
mapping[index] = class_name[0]
return mapping
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
image = image["image"]
image = ms.Tensor(image)
prob = model.predict(image)
label = np.argmax(prob.asnumpy(), axis=1)
mapping = index2label()
output = {int(label): mapping[int(label)]}
print(output)
show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
result=output,
out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
输出的结果会带标签,如下图。