Pytorch学习之图片位置编码
前提
在VisionTransformer模型中,使用一个二维的卷积核,将图片展开成一个patch序列
patch_embed = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
通过训练一个位置编码参数来学习记录图片的位置信息
num_patches
为图片展开的patch数目,加一是包含了cls_token,详细请阅读VisionTransformer论文
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
部分代码实现
# 对图片进行展开操作
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.proj(x)
x = x.flatten(2)
x = x.transpose(1, 2)
return x
x = PatchEmbed(img)
# 添加位置信息
x = x + pos_embed
位置编码转换
对于一个已经训练好的VisionTransformer模型,如何将学习的位置信息转换到一张任意分辨率的图片上
# 代码引用自https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - 1 # ,减去cls_token,得到输入图片的patch数量
N = self.pos_embed.shape[1] - 1 # 原模型训练时,patch数量
if npatch == N and w == h: #输入图片符合训练模型时的图片大小,可直接使用训练好的位置编码信息
return self.pos_embed
# 将位置编码转换到一张任意分辨率的图片上
class_pos_embed = self.pos_embed[:, 0] # 提取cls_token的位置编码信息
patch_pos_embed = self.pos_embed[:, 1:] # 提取图片patch序列的位置编码信息
dim = x.shape[-1]
# 对图片进行patch分割
w0 = w // self.patch_embed.patch_size
h0 = h // self.patch_embed.patch_size
w0, h0 = w0 + 0.1, h0 + 0.1
# 根据给定的size或scale_factor参数来对输入进行下/上采样
# 将原本在224*224训练得到的位置编码信息,转换到任意大小图片上
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), # 指定输出为输入的多少倍数。
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
# img已通过transform转换,并添加batch信息
B, c, w, h = img.shape
x = PatchEmbed(img)
# 添加cls_token,保持与模型一致
x = torch.cat((cls_tokens, x), dim=1)
# 添加位置信息
x = x + interpolate_pos_encoding(self, x, w, h)