from torch import nn
import torch
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
#
# embed_dim表示切好的图片拉成一维向量后的特征长度
#
# 图像共切分为N = HW/P^2个patch块
# 在实现上等同于对reshape后的patch序列进行一个PxP且stride为P的卷积操作
# output = {[(n+2p-f)/s + 1]向下取整}^2
# 即output = {[(n-P)/P + 1]向下取整}^2 = (n/P)^2
#
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
print(x.shape)
x = x.flatten(2)
print(x.shape)
x = x.transpose(1, 2)
return x # x.shape is [8, 196, 768]
# 输入
a = torch.Tensor(4, 3, 224, 224)
print(a.shape)
emb = PatchEmbed()
x = emb(a)
print(x.flatten(2).shape)
- 图像切片,将(batch_size, channel, Height, wide) 转为 (batchsize, n, embed_dim)
n = W ∗ H ∗ C P 2 n = \frac{ W*H*C} {P^2} n=P2W∗H∗C 这里等于196
e m b e d _ d i m = P 2 ∗ C embed\_dim = P^2*C embed_dim=P2∗C 这里等于768
img_size=224, patch_size=16, in_chans=3, embed_dim=768
将224x224*3的图像打成196个patch, 每个patch的size为768 , 即768维一维张量.
对应196sentence, 每个sentence的编码映射为768维的特征向量
咋眼看说去像是用大卷积核操作, 实际上只是借用它的方式, 最后在图像长宽维展平为一位再和第768个patch维对调实现了图像embedding.
**patch ebedding模块: (b, h, w) --> (b, n, embed_dim) 用推导公式, n随图片尺寸变, embed_dim根据设定的patch尺寸和图像通道变.可以理解为一张图像的空间几何信息转换为了语义信息. 这样做的目的是利用Transforer. **
- 自注意力 和MLP-Conv_MLP
- Patch Merging
# patch merging
import torch
import torch.nn as nn
import math
import numpy as np
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, c=3, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.c = c
self.reduction = nn.Linear(4*c*dim, 2*c*dim, bias=False) # 默认的线性层维度dim=1
self.norm = norm_layer(4*c*dim)
def forward(self, x):
"""
x: (B, H*W, C) , L = H*W: Resolution of input feature
"""
B, L, c = x.shape
H = int(math.sqrt(L))
W = int(math.sqrt(L))
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, c)
print('-------------------------------------------------------\n将input拆分为四个feature map x0, x1, x2, x3')
print('input(B, H, W, C): ', x.shape)
print('-------------------------------------------------------')
# 在行和列方向上间隔1选取元素, 0::3相隔2选取元素, 0::4 相隔3选取元素, 0::n 相隔n-1选取元素
# :: 前后对应了索引,总是选取两个元素
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
print('x0', x0.shape)
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
print('x1', x1.shape)
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
print('x2', x2.shape)
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
print('x3', x3.shape)
# 拼接到一起作为一整个张量
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
print('-------------------------------------------------------')
print('拼接整个张量后:', x.shape)
x = x.view(B, -1, 4*c) # B H/2*W/2 4*C
print('-------------------------------------------------------')
print('合并行和列后:', x.shape)
x = self.norm(x) # 归一化操作
print('-------------------------------------------------------')
print('归一化操作后:', x.shape)
x = self.reduction(x) # 通道降为原来的1/n维, n = 2
print('-------------------------------------------------------')
print('通道降低2倍后:', x.shape)
print('-------------------------------------------------------')
return x
if __name__ == "__main__":
# x = np.array([[0, 2, 0, 2], [1, 3, 1, 3], [0, 2, 0, 2], [1, 3, 1, 3]]) # (1, 4, 4)
# x = torch.from_numpy(x)
# x = x.view(1, 4 * 4, 1) # (1, 16, 1)
# x = x.to(torch.float32) # ()
model = PatchMerging(1)
# print('--------------------------')
x = torch.FloatTensor(32, 512*512, 3) # 随意定义一个tensor, 调整通道c, 更改PatchMerging同时初始化c
print('-------------------------------------------------------')
print('input(B, W*H, 1): ', x.shape)
y = model(x)
merging模块: 通道: c – > 2c. 通过该方式实现图像空间信息交流(相隔元素取并merging, 不损失特征的方式实现了信息提取, 提取到更深图像语义信息. 更为重要的是, 实现了无损失三维到二维的切换. 如果对应于自然语言, 则是降维了句子语义的特征空间, 这种降维的方式降句子的不同语义特征进行了交流融合
通道压缩-降维, 并没有对特征进行丢弃.