from collections import OrderedDict
def convert_state_dict(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from keys.

    Intended for state dicts saved from a ``DataParallel``-wrapped model, whose
    parameter names carry a ``module.`` prefix, so they can be loaded into the
    plain (unwrapped) module.

    The input mapping is NOT modified; a new ``OrderedDict`` is returned with
    the original key order preserved. Keys that do not start with ``prefix``
    are copied unchanged, so calling this on an already-plain state dict is
    harmless.

    :param state_dict: mapping of parameter name -> value, e.g. the loaded
        DataParallel model state.
    :param prefix: prefix to strip from each key (default ``'module.'``).
    :return: a new ``OrderedDict`` with the prefix removed where present.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Strip only when the prefix is actually there; slicing a fixed number
        # of characters unconditionally would mangle keys from a checkpoint
        # that was saved without DataParallel.
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict
# Usage:
# Load the generator weights from the checkpoint. Tensors are mapped to
# 'cuda:0' when running on GPU and to 'cpu' otherwise, and any DataParallel
# key prefix is stripped before loading into netG.
state = convert_state_dict(
    torch.load(args.checkpoint, map_location='cuda:0' if cuda else 'cpu')['G_state']
)
netG.load_state_dict(state)
if cuda:
    # Move the generator onto the GPU after its weights are in place.
    netG = netG.cuda()