版本:torch==1.12.1
计算机视觉算法中的上采用函数,一般都是采用nearest这种简单模式进行,PyTorch也是有相应的模块`nn.Upsample`来支持。但这个模块实在是不太与时俱进,慎用!
1. 不支持bf16
RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'BFloat16'
2. nn.Upsample 导致模型可复现性变差
参考这里。
---
解法就是,自定义Upsample模块,这里也是参考了这里:
import torch.nn as nn
class UpsampleDeterministic(nn.Module):
def __init__(self,upscale=2):
super(UpsampleDeterministic, self).__init__()
self.upscale = upscale
def forward(self, x):
'''
x: 4-dim tensor. shape is (batch,channel,h,w)
output: 4-dim tensor. shape is (batch,channel,self.upscale*h,self.upscale*w)
'''
return x[:, :, :, None, :, None]\
.expand(-1, -1, -1, self.upscale, -1, self.upscale)\
.reshape(x.size(0), x.size(1), x.size(2)\
*self.upscale, x.size(3)*self.upscale)