一、由头
在这里我们开一个新坑,这个新坑主要用来讲解Vision Transformer的详细实现细节。
关于Vision Transformer我们预计开四篇文章把它讲透,每篇文章我们都会代码和输入的样例。
关于Vision Transformer的第一篇文章,我们就先实现一下Vision Transformer的第一步——如果把图片变成patches然后输入到Transformer的Encoder当中。
图片来源于Vit原文
首先我们再回顾一下Vit的原文的标题:
图片来源于Vit原文
我们这里把红框当中的英文着重翻译一下,译为:一张图片可以等价为形状为16×16的单词。
这句话初看有点抽象(或者说我翻译的有点抽象),但实际上过程是这样的:
在Vit当中,每张输入图片的大小都被resize为224×224,此时我们用一个kernel_size为14,stride为14的卷积核去对该图进行操作,卷积完之后,Feature_map的形状就变为16×16×channel_nums,此时的16×16就和标题呼应上了。
接下来,我们就用代码来对应红框中的每一步操作:
二、分步代码
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
img = Image.open('chart_example_1.png').convert('RGB') #随便读取一个图片
img = img.resize((224, 224)) #转换成224 * 224的大小
img_tensor = transforms.ToTensor()(img) #转换成tensor类型 img_tensor = img_tensor.unsqueeze(0) #这里我们在前面增加一个维度表示batch_size
print(img_tensor.shape)
#
torch.Size([1, 3, 224, 224])
上述代码用来示例图片。我们将加载完毕的图片打印出来后可以看到形状为1×3×224×224,其中1是指批量大小,3是指图片的RGB通道数,224×224代表图片的长宽。
## 将图片转换成
patch in_channels = img_tensor.shape[1] #提取图片的输入通道
embed_dim = 768 #设置图片的输出通道,即embedding_dim
patch_size = 16 #设置patch_size,对应的是卷积核的大小和卷积核stride
num_patches = (224 // patch_size) * (224 // patch_size) #这里算出卷积之后的Feature_map形状 patch_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) #根据参数设置对应的卷积核
patch_tensor = patch_conv(img_tensor) #得到卷积之后的图片 print(patch_tensor.shape) #打印patch的形状 patch_tensor = patch_tensor.flatten(2) #将Feature_map的长和宽拉平,当成token patch_tensor = patch_tensor.permute(0, 2, 1) #将embedding维度和token维度互换位置 print(patch_tensor.shape) #打印形状
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) #设计cls_token,embedding_dim和patch的embedding保持一致
pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim)) #pos_embedding,embed_dim和patch的embedding保持一致,并且第二个维度上的大小为num_patches + 1 #这里+1是因为我们的patch会加上一个cls_token
pos_drop = nn.Dropout(0.1) #设置一个drop_out
patch_tensor = torch.cat((cls_token, patch_tensor), dim=1) #把cls_token和patch_tensor在第二个维度上拼接起来
patch_tensor += pos_embed #patch_tensor和pos_embeding相加
patch_tensor = pos_drop(patch_tensor)
print(patch_tensor.shape)
#
torch.Size([1, 768, 14, 14])
torch.Size([1, 196, 768])
torch.Size([1, 197, 768])
三、集成代码(写成class)
上述代码是我们一步一步去实现的patch操作,显得非常不规范。下面我们将这些代码进行集成,写成一个class,提高代码的复用性和观赏性。
import torch.nn as nn
import torch
class patch_embedding(nn.Module):
def __init__(self, in_channels, embed_dim, num_patches, patch_size):
super(patch_embedding, self).__init__()
self.patcher = nn.Sequential( nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size), nn.Flatten(2) )
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))
self.pos_drop = nn.Dropout(0.1)
def forward(self, x):
x = self.patcher(x).permute(0, 2, 1)
print(x.shape)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
print(cls_tokens.shape)
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x + self.pos_embed) return x
这里我们把刚才所有的torch函数都集成在一个class当中,大家可以对一一对应一下代码。
import torch
import numpy as np
import torch.nn as nn
from PIL import Image
from torchvision import transforms
img = Image.open('chart_example_1.png').convert('RGB')
img = img.resize((224, 224)) #转换成224 * 224的大小
img_tensor = transforms.ToTensor()(img) #转换成tensor类型 img_tensor = img_tensor.unsqueeze(0)
print(img_tensor.shape)
x = patch_embedding(in_channels = 3, embed_dim = 768, num_patches = 196, patch_size = 16)(img_tensor)
print(x.shape)
# torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 197, 768])
我们依旧使用之前的图片作为测试样例,得到了上述输出。