[pytorch]unfold操作

unfold函数是一种图片的分块操作,可以提取出卷积核扫过的元素,并不做其它的运算。我们平时调用pytorch的卷积接口,使用的是Conv,是一个完全封装好的过程。使用unfold,相当于将kernel滑动扫过的数据提取出来。

x = torch.range(0,15).view(4,4).unsqueeze(dim=0).unsqueeze(dim=0) #[1,1,4,4]
'''
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
'''
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x).transpose(-1,-2) #[1,1*4,4]  [B,C*kernelsize[0]*kernelsize[1],patch_num]
'''
tensor([[[ 0.,  1.,  4.,  5.],
         [ 2.,  3.,  6.,  7.],
         [ 8.,  9., 12., 13.],
         [10., 11., 14., 15.]]])
          '''

就像ViT模型的patch embedding操作一样,可以得到一个patch中的所有元素。

x = torch.range(0,15).view(4,4).unsqueeze(dim=0).unsqueeze(dim=0)
print(x)
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x)
'''
tensor([[[ 0.,  2.,  8., 10.],
         [ 1.,  3.,  9., 11.],
         [ 4.,  6., 12., 14.],
         [ 5.,  7., 13., 15.]]])
'''

unfold(x)的输出行就是卷积核当前位置在这个矩阵上所扫过的元素(例如0、2、8、10就是每次卷积核最左上角扫过的元素)。按照列来看就是在所有channel上的每一块元素的concat。(例如第一列就是所有channel上第一块里面的元素的拼接)

x = torch.range(0,31).view(2,4,4).unsqueeze(dim=0)
print(x)
'''
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.]]]])
'''
unfold = nn.Unfold(kernel_size=2,stride=2)
unfold(x)
'''
tensor([[[ 0.,  2.,  8., 10.],
         [ 1.,  3.,  9., 11.],
         [ 4.,  6., 12., 14.],
         [ 5.,  7., 13., 15.],
         [16., 18., 24., 26.],
         [17., 19., 25., 27.],
         [20., 22., 28., 30.],
         [21., 23., 29., 31.]]])
'''

因此可以用unfold来实现一个ViT中的patch_embedding

def img2emb_naive(img,patch_size,weight):
    # image size: [B,C,H,W]
    patch = F.unfold(img,kernel_size=patch_size,stride=patch_size).transpose(-1,-2)#图片分块
    print(patch.shape)   #1,4,12  12=3*2*2 也就是对应ViT的768 
    patch_embedding = patch@weight #weight就是一个可学习的参数,将原本的所有像素转换为embed的维度
    return patch_embedding

#test img2emb_naive
x = torch.randn(1,3,4,4)

patchsize = 2
emb_dim = 2
weight = torch.randn(3*patchsize*patchsize,emb_dim)

patch_embedding = img2emb_naive(x,patchsize,weight)
patch_embedding.shape #torch.Size([1, 4, 2])
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值