直接使用x.unfold()
import torch
h,w = 3,3
x = torch.arange(20).float()
print(x,x.shape)
x2d = x.reshape(1,1,4,5)
print(x2d,x2d.shape)
r1 = x2d.unfold(2,3,1)
print(r1,r1.shape)
r2 = x2d.unfold(3,3,1)
print(r2,r2.shape)
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19.]) torch.Size([20])
tensor([[[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.]]]]) torch.Size([1, 1, 4, 5])
tensor([[[[[ 0., 5., 10.],
[ 1., 6., 11.],
[ 2., 7., 12.],
[ 3., 8., 13.],
[ 4., 9., 14.]],
[[ 5., 10., 15.],
[ 6., 11., 16