from collections import OrderedDict
state_dict = torch.load(para_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
pytorch单GPU模型加载多GPU保存的参数
最新推荐文章于 2022-06-28 14:07:50 发布