pytorch中的nn.Unfold()函数和fold(函数详解

1.nn.Unfold()函数

描述:pytorch中的nn.Unfold()函数,在图像处理领域,经常需要用到卷积操作,但是有时我们只需要在图片上进行滑动的窗口操作,将图片切割成patch,而不需要进行卷积核和图片值的卷积乘法操作。这是就需要用到nn.Unfold()函数,该函数是从一个batch图片中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。

torch.nn.Unfold(kernel_size,dilation=1,paddding=0,stride=1)

该函数的输入是(bs,c,h,w),其中bs为batch-size,C是channel的个数。
而该函数的输出是(bs,Cxkernel_size[0]xkernel_size[1],L)其中L是特征图或者图片的尺寸根据kernel_size的长宽滑动裁剪后得到的多个patch的数量。

import torch.nn as nn
import torch
batches_img=torch.rand(1,2,4,4)#模拟图片数据(bs,2,4,4),通道数C为2
print("batches_img:\n",batches_img)

nn_Unfold=nn.Unfold(kernel_size=(2,2),dilation=1,padding=0,stride=2)
patche_img=nn_Unfold(batches_img)
print("patche_img.shape:",patche_img.shape)
print("patch_img:\n",patche_img)

在这里插入图片描述
该方法的主要应用场景是将图片切割成不同的patch,配合一下代码实现

#上面的代码能够获取到patch_img,(bs,C*K*K,L),L代表的是将每张图片分割成多少块
reshape_patche_img=patche_img.view(batches_img.shape[0],batches_img.shape[1],2,2,-1)
print(reshape_patche_img.shape)#[bs, C, k, k, L]
reshape_patche_img=reshape_patche_img.permute(0,4,1,2,3)#[N, L, C, k, k]
print(reshape_patche_img.shape)

结果:
在这里插入图片描述

2.nn.Fold()函数

该函数是nn.Unfold()函数的逆操作。

fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())

在这里插入图片描述

  • 31
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值