直接在GPU上加载:
# Load the checkpoint as saved: without map_location, torch.load restores each
# tensor to the device it was saved from (e.g. the GPU).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
pretrain = torch.load(opt.pretrain_path)
weights = pretrain['state_dict']
model.load_state_dict(weights)
将在GPU上保存的模型加载到CPU上（若保存时使用了 nn.DataParallel，参数名会带有 `module.` 前缀，需要一并去掉）:
# Load the checkpoint onto the CPU: the map_location callable returns each
# deserialized storage unchanged instead of moving it back to the GPU it was
# saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
pretrain = torch.load(opt.pretrain_path, map_location=lambda storage, loc: storage)
from collections import OrderedDict

# Parameter names saved from an nn.DataParallel-wrapped model carry a
# "module." prefix that a bare model does not expect. Strip it only when it is
# actually present: blindly dropping the first 7 characters would corrupt the
# keys of a checkpoint saved without DataParallel.
prefix = 'module.'
state_dict = OrderedDict(
    (key[len(prefix):] if key.startswith(prefix) else key, value)
    for key, value in pretrain['state_dict'].items()
)

# Keep every other checkpoint entry as-is (same order); only the weights are
# replaced by the prefix-stripped copy.
new_state_dict = OrderedDict(pretrain)
new_state_dict['state_dict'] = state_dict
model.load_state_dict(new_state_dict['state_dict'])