在Pytorch中可以用torch.unfold, torch.cat和torch.transpose的组合实现im2col操作.
TAKE AWAY:
stride = (1, 1)
kernel_size = (3, 3)
x = torch.arange(0, 25).resize_(5, 5)
y = torch.cat(torch.cat(x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1]).transpose(0, 2), dim=2).transpose(0, 1), dim=0)
下面以一个简单小矩阵举例详细说明单通道im2col操作:
x = torch.arange(0, 25).resize_(5, 5)
print(x)
0 1 2 3 4
5 6 7 8 9
10 11 12 13 14
15 16 17 18 19
20 21 22 23 24
[torch.FloatTensor of size 5x5]
定义卷积核大小和步长
kernel_size = (3,