state_dict = torch.load(config["pretrain_snapshot"], map_location=lambda storage, loc: storage)
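map_location accepts several equivalent forms for forcing a checkpoint onto the CPU. The following is a minimal sketch. It assumes a throwaway nn.Linear and a temporary file standing in for config["pretrain_snapshot"], both hypothetical. It saves a state dict and reloads it onto the CPU three ways:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the real pretrained network.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "snapshot.pth")  # hypothetical snapshot path
    torch.save(model.state_dict(), path)

    # Three equivalent ways to place every loaded tensor on the CPU.
    by_lambda = torch.load(path, map_location=lambda storage, loc: storage)
    by_string = torch.load(path, map_location="cpu")
    by_device = torch.load(path, map_location=torch.device("cpu"))

# Report the set of device types found in each loaded state dict.
for name, sd in [("lambda", by_lambda), ("str", by_string), ("device", by_device)]:
    print(name, {t.device.type for t in sd.values()})
```

The string form "cpu" is usually the most readable. The lambda form is the older idiom used in this note.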
Reference: https://blog.csdn.net/hardbird123/article/details/80549815
PyTorch: loading a GPU-trained model on the CPU
Loading directly on the GPU (by default, torch.load restores each tensor to the device it was saved from):
pretrain = torch.load(opt.pretrain_path)
model.load_state_dict(pretrain['state_dict'])
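The code above assumes a GPU is present. On a CPU-only machine, loading a checkpoint saved from CUDA tensors without map_location raises a RuntimeError. The following is a minimal sketch that picks the device at run time. It assumes the checkpoint is a dict with a 'state_dict' entry, as in the code above. The helper name load_pretrained and the toy nn.Linear models are hypothetical:

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_pretrained(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load a {'state_dict': ...} checkpoint onto the available device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps tensors saved on any device to `device`.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    return model.to(device)


# Usage with toy models standing in for the real network.
src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "pretrain.pth")  # hypothetical path
    torch.save({"state_dict": src.state_dict()}, ckpt_path)
    dst = load_pretrained(dst, ckpt_path)
```

Choosing the device once and passing it as map_location means the same script runs unchanged on GPU and CPU machines.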
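Loading a GPU-trained model on the CPU (map_location remaps every stored tensor; this lambda returns each storage as deserialized, which keeps it on the CPU):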
pretrain = torch.load(opt.pretrain_path, map_location=lambda storage, loc: storage)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in pretrain.items():
    if k == 'state_dict':
        stat