Where the bug occurs:
x = self.Down_wt(x)
x = F.interpolate(x, scale_factor=(2, 2), mode='nearest').to(torch.float32)
After the fix:
x = self.Down_wt(x)
with torch.no_grad():
    # Autograd does not record the upsampling done inside this block
    x = F.interpolate(x, scale_factor=(2, 2), mode='nearest').to(torch.float32)
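One side effect of wrapping the upsampling in `torch.no_grad()` is worth checking. The result no longer requires grad, so no gradient flows back through this path to `self.Down_wt` or to any layer before it. The minimal sketch below reproduces the pattern in a toy module. `Block`, `down` and `down_gets_gradient` are hypothetical names, and the strided `nn.Conv2d` is only an assumed stand-in for `self.Down_wt`, not the original module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Toy block: a downsampling layer followed by nearest-neighbour upsampling.

    `down` is a hypothetical stand-in for self.Down_wt (assumed here to be a
    strided convolution, not the original module).
    """

    def __init__(self, upsample_without_grad: bool):
        super().__init__()
        self.down = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1)
        self.upsample_without_grad = upsample_without_grad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.down(x)  # 8x8 -> 4x4
        if self.upsample_without_grad:
            # Nothing is recorded here, so the output has requires_grad=False
            with torch.no_grad():
                x = F.interpolate(x, scale_factor=(2, 2), mode='nearest').to(torch.float32)
        else:
            x = F.interpolate(x, scale_factor=(2, 2), mode='nearest').to(torch.float32)
        return x  # 4x4 -> 8x8


def down_gets_gradient(upsample_without_grad: bool) -> bool:
    """Return True if the downsampling layer receives a gradient from the output."""
    torch.manual_seed(0)
    block = Block(upsample_without_grad)
    out = block(torch.randn(1, 3, 8, 8))
    if not out.requires_grad:
        # out is detached from the graph; backward() on it would raise an error
        return False
    out.sum().backward()
    return block.down.weight.grad is not None


print(down_gets_gradient(False))  # True
print(down_gets_gradient(True))   # False
```

In a full model, `loss.backward()` may still run without error when later layers have trainable parameters. The only consequence is that `self.Down_wt`, and anything upstream of it that is reached only through this path, stops being trained. Whether that is acceptable depends on the model.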