我们知道,transformer要求将图片分为patch,然后输入网络进行计算,那么我们就需要将二维的图片处理成一维的embeding形式,今天我来给大家介绍一下图片处理的思路。
我们演示一下处理下面这张图片
我们将图片按照16*16的大小进行分片,得到的结果如下图所示
接下来我们需要将patch变成tensor。在此之前先介绍一下传统CNN图片处理和transformer图片处理之间的区别
我们可以看到,传统CNN图片处理得到的向量是三维的,而transformer图片处理得到的向量是二维的,其中num表示一张图片分片数量(也就是分成多少个patch),第二个维度中patch*patch表示每个patch的面积,channel表示通道数。
当我们训练网络的时候,通常需要将数据加载成batch的形式,一个batch里面通常包含多张图片,所以数据格式如下所示
也就是说,transformer送入网路进行计算的数据是三维的,而传统CNN送入网络进行计算的数据是四维的,这也是CNN和transformer数据加载的主要区别。
下面就贴一段数据处理的演示代码,你可以按照这个代码的思路去写数据加载器。
import torch
from PIL import Image
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
class ImgFactory(object):
def __init__(self, patch=16):
super(ImgFactory, self).__init__()
self.patch = patch
self.im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def getImagePatch(self, filename):
img = Image.open(filename)
width, height = img.size
num_patch_w = width // self.patch
num_patch_h = height // self.patch
patch_list = []
num = 1
for i in range(num_patch_h):
for j in range(num_patch_w):
s_y = i*self.patch
s_x = j*self.patch
box = (s_x, s_y, self.patch+s_x, self.patch+s_y)
region = img.crop(box)
patch_list.append(region)
plt.subplot(num_patch_h, num_patch_w, num), plt.imshow(region), plt.axis("off")
num = num + 1
plt.savefig("patch.png")
for i in range(len(patch_list)):
patch_list[i] = self.im_tfs(patch_list[i])
patch_list[i] = patch_list[i].view(1,-1)
seq = torch.cat(patch_list, dim=0)
return seq
if __name__ == "__main__":
factory = ImgFactory()
seq = factory.getImagePatch("a.png")
print(seq.shape)
输出结果是一张图片加载成tensor的格式