torch.nn.Unfold: extracting sliding local blocks
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
# kernel_size: size of the sliding window
# stride: stride of the sliding window over the spatial dimensions. Default: 1
# padding: implicit zero padding added on all four sides of the input. Default: 0
# dilation: spacing between kernel elements (the dilation rate of a dilated convolution). Default: 1
In the words of the official documentation, torch.nn.Unfold extracts sliding local blocks from a batched input tensor, i.e. exactly the windows that the kernel filter slides over during a convolution.
As the signature shows, the parameters of torch.nn.Unfold closely mirror those of nn.Conv2d: kernel_size (kernel size), dilation (dilation rate), padding (padding size) and stride (stride).
The official docs specify the input as (N, C, H, W), where N is the batch size, C the number of channels, and H and W the height and width of each channel. The output has shape (N, C × ∏(kernel_size), L), where L is the number of extracted blocks.
import torch

inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())   # torch.Size([1, 8, 4]): 8 = C * 2 * 2 values per block, 4 non-overlapping 2x2 blocks
print(patches)
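Since Unfold produces exactly the windows a convolution slides over, a convolution can be expressed as unfold followed by a matrix multiplication and a reshape, as the official docs point out. A minimal sketch of this equivalence, reusing inputs and patches from above and assuming a made-up weight tensor w:

# w is a hypothetical weight for illustration: 3 output channels, 2 input channels, 2x2 kernel
w = torch.randn(3, 2, 2, 2)
out_unf = w.view(w.size(0), -1) @ patches          # (3, 8) @ (1, 8, 4) -> (1, 3, 4): one matmul per block
out = out_unf.view(1, 3, 2, 2)                     # arrange the 4 block results into a 2x2 output map
conv = torch.nn.functional.conv2d(inputs, w, stride=2)
print(torch.allclose(out, conv, atol=1e-6))        # True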
torch.nn.Fold
torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)
torch.nn.Fold performs the inverse of Unfold: it combines the extracted sliding local blocks back into a batched tensor with the original spatial layout.
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())   # torch.Size([1, 2, 4, 4]): back to the original shape
By setting output_size=(4, 4), Fold exactly inverts the Unfold operation above.
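One caveat: Fold sums the contributions of every block that covers a position, so fold(unfold(x)) only reproduces x when the blocks do not overlap, as with stride=2 above. For overlapping blocks, the official Fold docs suggest dividing by the coverage count obtained from an all-ones tensor. A minimal sketch, assuming a 4×4 input with stride=1 so the 2×2 blocks overlap:

import torch

x = torch.randn(1, 2, 4, 4)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=1)              # overlapping 2x2 blocks
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)

summed = fold(unfold(x))                                            # overlapping positions are summed, so summed != x
divisor = fold(unfold(torch.ones_like(x)))                          # how many blocks cover each position
print(torch.allclose(summed / divisor, x))                          # True: normalizing recovers x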
Padding behaviour
The official docs describe the padding as being added to "both sides" of the input. Checked in code, this really means zero padding on all four sides (both sides of each spatial dimension), not just two edges; a short check and its result are shown below.
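A minimal sketch using a small all-ones input; summing each extracted patch shows how many real (non-padded) pixels each window covers:

import torch

x = torch.ones(1, 1, 2, 2)
unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1, stride=1)
patches = unfold(x)                    # (1, 4, 9): nine 2x2 windows over the zero-padded 4x4 input
print(patches.sum(dim=1).view(3, 3))
# tensor([[1., 2., 1.],
#         [2., 4., 2.],
#         [1., 2., 1.]])   -> corner windows see one real pixel, edges two, the centre four:
#                             zeros were added on every side of the input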
Code walkthrough
The following code shows that Unfold and Fold are inverse operations (exact here because the 2×2 blocks do not overlap when stride=2).
>>> import torch
>>> inputs = torch.randn(1,2,4,4)
>>> unfold = torch.nn.Unfold(kernel_size=(2,2), stride=2)
>>> patches = unfold(inputs)
>>> fold = torch.nn.Fold(output_size=(4,4), kernel_size=(2,2), stride=2)
>>> out = fold(patches)
>>> inputs
tensor([[[[ 0.2220,  0.4331, -0.4789,  0.1313],
          [-1.0165, -0.7690, -0.7106,  0.0249],
          [-0.3132,  0.0441, -1.8581, -0.5766],
          [ 0.5753,  1.8645, -1.7966,  0.3177]],

         [[-0.1142,  0.5476, -0.9398, -0.5508],
          [-0.8906, -1.5367, -1.1093,  0.9651],
          [-1.4868, -0.7046,  1.1245, -2.0049],
          [-0.1741, -0.2840,  1.1057, -0.6320]]]])
>>> patches
tensor([[[ 0.2220, -0.4789, -0.3132, -1.8581],
         [ 0.4331,  0.1313,  0.0441, -0.5766],
         [-1.0165, -0.7106,  0.5753, -1.7966],
         [-0.7690,  0.0249,  1.8645,  0.3177],
         [-0.1142, -0.9398, -1.4868,  1.1245],
         [ 0.5476, -0.5508, -0.7046, -2.0049],
         [-0.8906, -1.1093, -0.1741,  1.1057],
         [-1.5367,  0.9651, -0.2840, -0.6320]]])
>>> out
tensor([[[[ 0.2220,  0.4331, -0.4789,  0.1313],
          [-1.0165, -0.7690, -0.7106,  0.0249],
          [-0.3132,  0.0441, -1.8581, -0.5766],
          [ 0.5753,  1.8645, -1.7966,  0.3177]],

         [[-0.1142,  0.5476, -0.9398, -0.5508],
          [-0.8906, -1.5367, -1.1093,  0.9651],
          [-1.4868, -0.7046,  1.1245, -2.0049],
          [-0.1741, -0.2840,  1.1057, -0.6320]]]])
>>> inputs == out
tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])
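The flat dimension of patches is laid out as C × kh × kw (channel slowest, then kernel row, then kernel column), so it can be viewed back as explicit 2×2 blocks. A small sketch continuing the session above; blocks is just an illustrative name:

blocks = patches.view(1, 2, 2, 2, 4)   # (N, C, kh, kw, L)
print(blocks[0, 0, :, :, 0])           # the top-left 2x2 block of channel 0
# tensor([[ 0.2220,  0.4331],
#         [-1.0165, -0.7690]])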