pytorch 0.3.0升级到0.4.0之后部分之前保存的模型有些不能直接使用
判断版本:
import torch print(torch.__version__)
0.4.0相比0.3.0代码中需要修改的部分:
1.norm_layer
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)改为:
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
2.指定gpu
if len(gpu_ids) > 0: netG.cuda(device_id=gpu_ids[0])
改为:
if len(gpu_ids) > 0: netG.cuda(0)
3.
不需要梯度计算:
input_label = Variable(input_label,volatile=False)
改为:
with torch.no_grad():
input_label = Variable(input_label)
英文文档:https://pytorch.org/2018/04/22/0_4_0-migration-guide.html