官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html
这个东西基本上就是绑定Unfold使用的。实际上,在没有overlapping、参数相同的情况下,其与Unfold操作是互逆的。
官方对该函数作用的描述如下:
…This operation combines these local blocks into the large output tensor by summing the overlapping values…
这一操作通过对重叠的数值进行求和,将这些局部块结合到大的输出tensor中
说的比较含糊,那我们先上代码试一下unfold。对于一张1×1×4×4的特征图:
[[[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]]
对其进行2×2,stride=2的滑动窗口操作以unfold,实现如下:
import torch
import torch.nn as nn
x = torch.Tensor([[[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
unfold = nn.Unfold((2,2), stride=2)
print(x)
print(x.size())
输出unfold结果为:
tensor([[[ 1., 3., 9., 11.],
[ 2., 4., 10., 12.],
[ 5., 7., 13., 15.],
[ 6., 8., 14., 16.]]])
torch.Size([1, 4, 4])
再来看fold。前面我们看到,fold做的其实就是利用 h × w h×w h×w的核进行滑动窗口操作,然后将每次滑动得到的结果展平成一个列向量,逐个填充至结果中。那么unfold的话,做的工作就是处理fold得到的列向量。具体而言,unfold每次读取一个列向量,然后将其reshape回一个 h × w h×w h×w的块,再填回结果中。这时候就涉及一个问题,如果stride较小的话,reshape得到的块再填回结果时是会有overlapping的,因此只有在无overlapping(对于本例,需要stride=2)的情况下unfold与fold才可逆。
现在我们继续接着上面的例子,从unfold结果中提取第一列数据,将其reshape为2×2:
1 2
5 6
然后将其填充到3×3结果中,有:
[[[[0+1, 0+2, 0],
[0+5, 0+6, 0],
[ 0, 0, 0]]]]
继续提取第二列数据,将其reshape为2×2:
3 4
7 8
然后将其填充到3×3结果中。需要注意的是,由于stride=1,因此此时用于填充结果的kernel只会向右移一格,导致结果填充重叠:
[[[[1, 2+3, 0+4],
[5, 6+7, 0+8],
[0, 0, 0]]]]
继续提取第三列数据,将其reshape为2×2:
9 10
13 14
然后将其填充到3×3结果中:
[[[[ 1, 5, 4],
[ 5+9, 13+10, 8],
[0+13, 0+14, 0]]]]
提取第四列数据,将其reshape为2×2:
11 12
15 16
将其填充到3×3结果中,得到最后结果:
[[[[ 1, 5, 4],
[14, 23+11, 8+12],
[13, 14+15, 0+16]]]]
完整编码实现如下:
import torch
import torch.nn as nn
x = torch.Tensor([[[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 13, 14, 15, 16]]]])
print(x)
unfold = nn.Unfold((2,2), stride=2)
fold = nn.Fold(kernel_size=(2,2), stride=1, output_size=(3,3))
x = unfold(x)
print(x)
print(x.size())
x = fold(x)
print(x)
print(x.size())