from collections import OrderedDict
state_dict = torch.load(para_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
pytorch单GPU模型加载多GPU保存的参数
最新推荐文章于 2023-07-05 23:18:52 发布