Saving the network weights:
torch.save(G_AB.state_dict(), 'G_AB.pth')
Loading the weights on a single GPU:
G_AB.load_state_dict(torch.load('G_AB.pth'))
Loading single-GPU weights on the CPU:
g.load_state_dict(torch.load('./save/G.pth', map_location='cpu'))  # map the saved tensors onto the CPU
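If the model was trained on multiple GPUs and is tested on a single GPU, strip the `module.` prefix from the keys before loading the weights:
from collections import OrderedDict
state_dictBA = torch.load(opt.pth)
# create a new OrderedDict whose keys do not contain `module.`
new_state_dictBA = OrderedDict()
for k, v in state_dictBA.items():
    name = k[7:] if k.startswith('module.') else k  # remove the leading `module.`
    new_state_dictBA[name] = v
G_BA.load_state_dict(new_state_dictBA)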
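The key renaming above can be wrapped in a small reusable helper. This is a minimal sketch, not part of PyTorch: the name `strip_module_prefix` and the `fake_checkpoint` dictionary are hypothetical, and plain integers stand in for tensors so the snippet runs without a GPU or a saved file.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: keys that do not start with the prefix are copied unchanged,
    so it is safe to call on a checkpoint saved from a single-GPU model as well.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for a checkpoint saved from a multi-GPU model (values would be tensors).
fake_checkpoint = OrderedDict([
    ("module.conv1.weight", 1),
    ("module.conv1.bias", 2),
    ("fc.weight", 3),  # already without prefix, left untouched
])

cleaned_checkpoint = strip_module_prefix(fake_checkpoint)
print(list(cleaned_checkpoint.keys()))
# ['conv1.weight', 'conv1.bias', 'fc.weight']
```

With a real checkpoint, the result would be passed to `load_state_dict` in the same way as `new_state_dictBA` above. Checking `startswith` rather than always slicing off seven characters avoids corrupting keys that never had the prefix.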