import torch
# Demo: 2x2 max pooling followed by max unpooling on a 4x4 input.
# Unpooling writes each pooled maximum back to the position it was taken
# from (as recorded in `indices`) and fills every other position with zero.
a = torch.tensor(
    [
        [1.0, 1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0, 2.0],
        [1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0],
    ]
)

# Non-overlapping 2x2 windows (kernel 2, stride 2); the pool also returns the
# flat index of each window's maximum so the unpool can invert it.
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = torch.nn.MaxUnpool2d(kernel_size=2, stride=2)

# The 2d pooling layers are fed a (batch, channel, H, W) input here, so add
# two leading singleton dimensions to the 4x4 matrix.
pooled, indices = pool(a[None, None])
print(pooled)

# Scatter the pooled maxima back into a zero-filled tensor of the input size.
b = unpool(pooled, indices)
print(b)