torch.nn.Fold的操作与Unfold相反,将提取出的滑动局部区域块还原成batch的张量形式。
Fold — PyTorch 1.10.0 documentationhttps://pytorch.org/docs/stable/generated/torch.nn.Fold.html?highlight=fold#torch.nn.Fold
Code:
>>> import torch
>>> x=torch.arange(16,dtype=torch.float)
>>> x
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15.])
>>> x=x.view(1,1,4,4)
>>> x
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]])
>>> unfold=torch.nn.Unfold(kernel_size=2,stride=2)
>>> y=unfold(x)
>>> y.size()
torch.Size([1, 4, 4])
>>> y
tensor([[[ 0., 2., 8., 10.],
[ 1., 3., 9., 11.],
[ 4., 6., 12., 14.],
[ 5., 7., 13., 15.]]])
>>> fold=torch.nn.Fold(output_size=(4,4),kernel_size=2,stride=2)
>>> z=fold(y)
>>> z.shape
torch.Size([1, 1, 4, 4])
>>> z
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]]])