unfold的作用就是手动实现的滑动窗口操作,也就是只有卷,没有积;不过相比于nn.functional中的unfold而言,其窗口的意味更浓,只能是一维的,也就是不存在类似2×2窗口的说法。
ret = x.unfold(dim, size, step)
- dim:int,表示需要展开的维度(可以理解为窗口的方向)
- size:int,表示滑动窗口大小
- step:int,表示滑动窗口的步长
例如,对于一张5×5的如下特征图:
[[[[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[ 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20],
[ 21, 22, 23, 24, 25]]]]
对其进行长度为3,步长为1的滑动窗口操作,分两种情况:
窗口为3×1:
1 -> 2 -> 3
6 7 8
11 12 13
窗口为1×3:
1 -> 2 -> 3
2 3 4
3 4 5
如果是第一种情况,此时窗口的方向dim=0(第0维为3,其他维为1);如果是第二种情况,此时窗口的方向dim=1(第1维为3,其他维为1)。
代码如下:
import torch
x = torch.Tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[ 11, 12, 13, 14, 15],
[ 16, 17, 18, 19, 20],
[ 21, 22, 23, 24, 25]])
x = x.unfold(0, 3, 1)
print(x)
print(x.size())
输出:
tensor([[[ 1., 6., 11.],
[ 2., 7., 12.],
[ 3., 8., 13.],
[ 4., 9., 14.],
[ 5., 10., 15.]],
[[ 6., 11., 16.],
[ 7., 12., 17.],
[ 8., 13., 18.],
[ 9., 14., 19.],
[10., 15., 20.]],
[[11., 16., 21.],
[12., 17., 22.],
[13., 18., 23.],
[14., 19., 24.],
[15., 20., 25.]]])
torch.Size([3, 5, 3])
其中,每一行表示滑动窗口每次移动所覆盖的内容。例如在这里滑动窗口可以从左往右滑动5次,因此每维有5行;而每滑动5次需要向下移动一次,一共需要移动2次,因此有2+1=3维。