torch.nn.PixelShuffle()演示
import torch
import torch.nn as nn
r = 2 # 上采样倍率
PS = nn.PixelShuffle(r) # 初始化亚像素卷积操作
x = torch.arange(3*4*9).reshape(1, 3*(r**2),3, 3)
print(f'*****************************************')
print(f'input is \n{x}, and size is {x.size()}')
y = PS(x) # 亚像素上采样
print(f'*****************************************')
print(f'output is \n{y}, and size is {y.size()}')
print(f'*****************************************')
print(f'upscale_factor is {PS.extra_repr()}')
print(f'*****************************************')
运行结果为: