Vit step by step -- Image 2 patch

一、由头

在这里我们开一个新坑,这个新坑主要用来讲解Vision Transformer的详细实现细节。

关于Vision Transformer我们预计开四篇文章把它讲透,每篇文章我们都会代码和输入的样例。

关于Vision Transformer的第一篇文章,我们就先实现一下Vision Transformer的第一步——如果把图片变成patches然后输入到Transformer的Encoder当中。

图片来源于Vit原文

首先我们再回顾一下Vit的原文的标题:

图片来源于Vit原文

我们这里把红框当中的英文着重翻译一下,译为:一张图片可以等价为形状为16×16的单词。

这句话初看有点抽象(或者说我翻译的有点抽象),但实际上过程是这样的:

在Vit当中,每张输入图片的大小都被resize为224×224,此时我们用一个kernel_size为14,stride为14的卷积核去对该图进行操作,卷积完之后,Feature_map的形状就变为16×16×channel_nums,此时的16×16就和标题呼应上了。

接下来,我们就用代码来对应红框中的每一步操作:

二、分步代码

import torch 
import torch.nn as nn 
from PIL import Image 
from torchvision import transforms 
img = Image.open('chart_example_1.png').convert('RGB') #随便读取一个图片 
img = img.resize((224, 224)) #转换成224 * 224的大小 
img_tensor = transforms.ToTensor()(img) #转换成tensor类型 img_tensor = img_tensor.unsqueeze(0) #这里我们在前面增加一个维度表示batch_size 
print(img_tensor.shape) 
# 
torch.Size([1, 3, 224, 224])

上述代码用来示例图片。我们将加载完毕的图片打印出来后可以看到形状为1×3×224×224,其中1是指批量大小,3是指图片的RGB通道数,224×224代表图片的长宽。

## 将图片转换成
patch in_channels = img_tensor.shape[1] #提取图片的输入通道 
embed_dim = 768 #设置图片的输出通道,即embedding_dim 
patch_size = 16 #设置patch_size,对应的是卷积核的大小和卷积核stride 
num_patches = (224 // patch_size) * (224 // patch_size) #这里算出卷积之后的Feature_map形状 patch_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) #根据参数设置对应的卷积核 
patch_tensor = patch_conv(img_tensor) #得到卷积之后的图片 print(patch_tensor.shape) #打印patch的形状 patch_tensor = patch_tensor.flatten(2) #将Feature_map的长和宽拉平,当成token patch_tensor = patch_tensor.permute(0, 2, 1) #将embedding维度和token维度互换位置 print(patch_tensor.shape) #打印形状 
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) #设计cls_token,embedding_dim和patch的embedding保持一致 
pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim)) #pos_embedding,embed_dim和patch的embedding保持一致,并且第二个维度上的大小为num_patches + 1 #这里+1是因为我们的patch会加上一个cls_token 
pos_drop = nn.Dropout(0.1) #设置一个drop_out 
patch_tensor = torch.cat((cls_token, patch_tensor), dim=1) #把cls_token和patch_tensor在第二个维度上拼接起来 
patch_tensor += pos_embed #patch_tensor和pos_embeding相加 
patch_tensor = pos_drop(patch_tensor) 
print(patch_tensor.shape) 
# 
torch.Size([1, 768, 14, 14]) 
torch.Size([1, 196, 768]) 
torch.Size([1, 197, 768])

三、集成代码(写成class)

上述代码是我们一步一步去实现的patch操作,显得非常不规范。下面我们将这些代码进行集成,写成一个class,提高代码的复用性和观赏性。

import torch.nn as nn
import torch 
class patch_embedding(nn.Module): 
    def __init__(self, in_channels, embed_dim, num_patches, patch_size):                         
        super(patch_embedding, self).__init__() 
        self.patcher = nn.Sequential( nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size), nn.Flatten(2) ) 
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))     
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))     
        self.pos_drop = nn.Dropout(0.1) 
    def forward(self, x): 
        x = self.patcher(x).permute(0, 2, 1) 
        print(x.shape) 
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 
        print(cls_tokens.shape) 
        x = torch.cat((cls_tokens, x), dim=1) 
        x = self.pos_drop(x + self.pos_embed) return x

这里我们把刚才所有的torch函数都集成在一个class当中,大家可以对一一对应一下代码。

import torch 
import numpy as np 
import torch.nn as nn 
from PIL import Image 
from torchvision import transforms 
img = Image.open('chart_example_1.png').convert('RGB') 
img = img.resize((224, 224)) #转换成224 * 224的大小 
img_tensor = transforms.ToTensor()(img) #转换成tensor类型 img_tensor = img_tensor.unsqueeze(0) 
print(img_tensor.shape) 
x = patch_embedding(in_channels = 3, embed_dim = 768, num_patches = 196, patch_size = 16)(img_tensor) 
print(x.shape) 
# torch.Size([1, 3, 224, 224]) 
torch.Size([1, 196, 768]) 
torch.Size([1, 1, 768]) 
torch.Size([1, 197, 768])

我们依旧使用之前的图片作为测试样例,得到了上述输出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会震pop的码农

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值