unfold函数是一种图片的分块操作,可以提取出卷积核扫过的元素,并不做其它的运算。我们平时调用pytorch的卷积接口,使用的是Conv,是一个完全封装好的过程。使用unfold,相当于将kernel滑动扫过的数据提取出来。
x = torch.range(0,15).view(4,4).unsqueeze(dim=0).unsqueeze(dim=0) #[1,1,4,4]
'''
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]])
'''
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x).transpose(-1,-2) #[1,1*4,4] [B,C*kernelsize[0]*kernelsize[1],patch_num]
'''
tensor([[[ 0., 1., 4., 5.],
[ 2., 3., 6., 7.],
[ 8., 9., 12., 13.],
[10., 11., 14., 15.]]])
'''
就像ViT模型的patch embedding操作一样,可以得到一个patch中的所有元素。
x = torch.range(0,15).view(4,4).unsqueeze(dim=0).unsqueeze(dim=0)
print(x)
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x)
'''
tensor([[[ 0., 2., 8., 10.],
[ 1., 3., 9., 11.],
[ 4., 6., 12., 14.],
[ 5., 7., 13., 15.]]])
'''
unfold(x)的输出行就是卷积核当前位置在这个矩阵上所扫过的元素(例如0、2、8、10就是每次卷积核最左上角扫过的元素)。按照列来看就是在所有channel上的每一块元素的concat。(例如第一列就是所有channel上第一块里面的元素的拼接)
x = torch.range(0,31).view(2,4,4).unsqueeze(dim=0)
print(x)
'''
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]],
[[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.]]]])
'''
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x)
'''
tensor([[[ 0., 2., 8., 10.],
[ 1., 3., 9., 11.],
[ 4., 6., 12., 14.],
[ 5., 7., 13., 15.],
[16., 18., 24., 26.],
[17., 19., 25., 27.],
[20., 22., 28., 30.],
[21., 23., 29., 31.]]])
'''
因此可以用unfold来实现一个ViT中的patch_embedding
def img2emb_naive(img,patch_size,weight):
# image size: [B,C,H,W]
patch = F.unfold(img,kernel_size=patch_size,stride=patch_size).transpose(-1,-2)#图片分块
print(patch.shape) #1,4,12 12=3*2*2 也就是对应ViT的768
patch_embedding = patch@weight #weight就是一个可学习的参数,将原本的所有像素转换为embed的维度
return patch_embedding
#test img2emb_naive
x = torch.randn(1,3,4,4)
patchsize = 2
emb_dim = 2
weight = torch.randn(3*patchsize*patchsize,emb_dim)
patch_embedding = img2emb_naive(x,patchsize,weight)
patch_embedding.shape #torch.Size([1, 4, 2])