「详解」torch.nn.Fold和torch.nn.Unfold操作

Warning!!!
Fold:Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.
UnFold:Currently, only 4-D input tensors (batched image-like tensors) are supported.

1、torch.nn.Unfold 滑动裁剪

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
  1. kernel_size:滑动窗口的size
  2. stride:空间维度上滑动的步长,Default: 1
  3. padding:在输入的四周 赋零填充. Default: 0
  4. dilation:空洞卷积的扩充率,Default: 1

在这里插入图片描述

torch.nn.Unfold按照官方的说法,既从一个batch的样本中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。

  1. 由上可知,torch.nn.Unfold的参数跟nn.Conv2d的参数很相似,即,kernel_size(卷积核的尺寸),dilation(空洞大小),padding(填充大小)和stride(步长)
  2. 官方解释中:unfold的输入为( N, C, H, W),其中N为batch_size,C是channel个数,H和W分别是channel的长宽。
    unfold 输出为 ( N , C × ∏ ( k e r n e l _ s i z e ) , L ) ( N, C\times \prod (kernel\_size), L) N,C×(kernel_size),L,其中 ∏ ( k e r n e l _ s i z e ) \prod (kernel\_size) (kernel_size) k e r n e l s i z e kernel_size kernelsize 长和宽的乘积
    L L L ( H , W ) (H,W) (H,W)根据 k e r n e l _ s i z e kernel\_size kernel_size 尺寸 滑动裁剪后,得到的区块的数量。

Shape:

  • Input: ( N , C , ∗ ) (N,C,∗) (N,C,)
  • Output: ( N , C × ∏ ( k e r n e l s i z e ) , L ) (N,C×∏(kernel_{size}),L) (N,C×(kernelsize),L)
    L = ∏ d ∈ ( 0 , 1 ) ⌊ f e a t u r e _ s i z e [ d ] + 2 p a d d i n g [ d ] − d i l a t i o n [ d ] ⋅ ( k e r n e l _ s i z e [ d ] − 1 ) − 1 s t r i d e _ [ d ] + 1 ⌋ L= \prod_{d \in(0,1)} \mathrm{ \Biggl\lfloor \frac{feature\_size[d]+ \color{red} 2padding[d]−\color{green} dilation[d] \cdot (kernel\_size[d]−1)\color{back}−1}{stride\_[d] } +1 \Biggr\rfloor} L=d(0,1)stride_[d]feature_size[d]+2padding[d]dilation[d](kernel_size[d]1)1+1
    ⌊ a r g ⌋ \lfloor \mathrm{arg} \rfloor arg 表示向下取整:函数返回不大于arg的最大整数值

inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold  = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)

nn.Unfold对输入channel的每一个 k e r n e l s i z e [ 0 ] × k e r n e l s i z e [ 1 ] kernel_{size[ 0 ]} \times kernel_{size[ 1 ]} kernelsize[0]×kernelsize[1]的滑动窗口区块做了展平操作。
在这里插入图片描述



2、torch.nn.Fold

torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)

torch.nn.Fold的操作与Unfold相反,将提取出的滑动局部区域块还原成batch的张量形式。

在这里插入图片描述

fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())

Fold的操作通过设定output_size=(4, 4),完成与Unfold的互逆的操作。


3、Padding 填充操作解析

官方采用的描述词是 both sides, 博主通过代码验证了下,确实是四边全部赋零操作,而不能简单的翻译为两边,下面是代码与结果展示

在这里插入图片描述


4、代码解析

通过一下代码,我们可以看到 Unfold 与 Fold 是互逆过程。

>>> import torch
>>> inputs = torch.randn(1,2,4,4)
>>> unfold = torch.nn.Unfold(kernel_size=(2,2), stride=2)
>>> patches = unfold(inputs)
>>> fold = torch.nn.Fold(output_size=(4,4), kernel_size=(2,2), stride=2)
>>> out = fold(patches)

>>> inputs
tensor([[[[ 0.2220,  0.4331, -0.4789,  0.1313],
          [-1.0165, -0.7690, -0.7106,  0.0249],
          [-0.3132,  0.0441, -1.8581, -0.5766],
          [ 0.5753,  1.8645, -1.7966,  0.3177]],
          
         [[-0.1142,  0.5476, -0.9398, -0.5508],
          [-0.8906, -1.5367, -1.1093,  0.9651],
          [-1.4868, -0.7046,  1.1245, -2.0049],
          [-0.1741, -0.2840,  1.1057, -0.6320]]]])

>>> patches
tensor([[[ 0.2220, -0.4789, -0.3132, -1.8581],
         [ 0.4331,  0.1313,  0.0441, -0.5766],
         [-1.0165, -0.7106,  0.5753, -1.7966],
         [-0.7690,  0.0249,  1.8645,  0.3177],
         [-0.1142, -0.9398, -1.4868,  1.1245],
         [ 0.5476, -0.5508, -0.7046, -2.0049],
         [-0.8906, -1.1093, -0.1741,  1.1057],
         [-1.5367,  0.9651, -0.2840, -0.6320]]])

>>> out
tensor([[[[ 0.2220,  0.4331, -0.4789,  0.1313],
          [-1.0165, -0.7690, -0.7106,  0.0249],
          [-0.3132,  0.0441, -1.8581, -0.5766],
          [ 0.5753,  1.8645, -1.7966,  0.3177]],
          
         [[-0.1142,  0.5476, -0.9398, -0.5508],
          [-0.8906, -1.5367, -1.1093,  0.9651],
          [-1.4868, -0.7046,  1.1245, -2.0049],
          [-0.1741, -0.2840,  1.1057, -0.6320]]]])

>>> inputs == out
tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],
          
         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])

参考

  • 93
    点赞
  • 182
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 32
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 32
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ViatorSun

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值