Pytorch nn.Fold()的简单理解与用法

本文详细介绍了PyTorch中的Unfold和Fold操作,通过实例展示了它们如何对特征图进行滑动窗口处理。Unfold将特征图拆分成小块,而Fold则将这些块重组回原图。当stride等于kernel_size时,这两个操作互逆。文章通过代码演示了这两个操作的过程,并解释了在有重叠情况下的不适用性。
摘要由CSDN通过智能技术生成

官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html

这个东西基本上就是绑定Unfold使用的。实际上,在没有overlapping、参数相同的情况下,其与Unfold操作是互逆的。

官方对该函数作用的描述如下:
…This operation combines these local blocks into the large output tensor by summing the overlapping values…
这一操作通过对重叠的数值进行求和,将这些局部块结合到大的输出tensor中

说的比较含糊,那我们先上代码试一下unfold。对于一张1×1×4×4的特征图:

[[[[  1,  2,  3,  4],
   [  5,  6,  7,  8],
   [  9, 10, 11, 12],
   [ 13, 14, 15, 16]]]]

对其进行2×2,stride=2的滑动窗口操作以unfold,实现如下:

import torch
import torch.nn as nn
x = torch.Tensor([[[[  1,  2,  3,  4],
   					[  5,  6,  7,  8], 
   					[  9, 10, 11, 12],
   					[ 13, 14, 15, 16]]]])
unfold = nn.Unfold((2,2), stride=2)
print(x)
print(x.size())

输出unfold结果为:

tensor([[[ 1.,  3.,  9., 11.],
         [ 2.,  4., 10., 12.],
         [ 5.,  7., 13., 15.],
         [ 6.,  8., 14., 16.]]])
torch.Size([1, 4, 4])

再来看fold。前面我们看到,fold做的其实就是利用 h × w h×w h×w的核进行滑动窗口操作,然后将每次滑动得到的结果展平成一个列向量,逐个填充至结果中。那么unfold的话,做的工作就是处理fold得到的列向量。具体而言,unfold每次读取一个列向量,然后将其reshape回一个 h × w h×w h×w的块,再填回结果中。这时候就涉及一个问题,如果stride较小的话,reshape得到的块再填回结果时是会有overlapping的,因此只有在无overlapping(对于本例,需要stride=2)的情况下unfold与fold才可逆。

现在我们继续接着上面的例子,从unfold结果中提取第一列数据,将其reshape为2×2:

1 2 
5 6

然后将其填充到3×3结果中,有:

[[[[0+1, 0+2, 0],
   [0+5, 0+6, 0],
   [  0,   0, 0]]]]

继续提取第二列数据,将其reshape为2×2:

3 4
7 8

然后将其填充到3×3结果中。需要注意的是,由于stride=1,因此此时用于填充结果的kernel只会向右移一格,导致结果填充重叠:

[[[[1, 2+3, 0+4],
   [5, 6+7, 0+8],
   [0,   0,   0]]]]

继续提取第三列数据,将其reshape为2×2:

9  10
13 14

然后将其填充到3×3结果中:

[[[[   1,     5, 4],
   [ 5+9, 13+10, 8],
   [0+13,  0+14, 0]]]]

提取第四列数据,将其reshape为2×2:

11 12 
15 16

将其填充到3×3结果中,得到最后结果:

[[[[ 1,     5,    4],
   [14, 23+11, 8+12],
   [13, 14+15, 0+16]]]]

完整编码实现如下:

import torch
import torch.nn as nn

x = torch.Tensor([[[[  1,  2,  3,  4],
   					[  5,  6,  7,  8],
   					[  9, 10, 11, 12],
   					[ 13, 14, 15, 16]]]])

print(x)
unfold = nn.Unfold((2,2), stride=2)
fold = nn.Fold(kernel_size=(2,2), stride=1, output_size=(3,3))
x = unfold(x) 
print(x)
print(x.size())
x = fold(x)
print(x)
print(x.size())
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值