使用DataParallel保存的模型在保存后载入的时候存在权重不匹配的问题,原因是每层的参数是
会比原来多一个module的设置。
import torch
from collections import OrderedDict#python自带的库
state_dict=torch.load(config.model_save_path+"_best")["model"]#我的权重存储在文件的model字段里面的
'''
state = {
'model': model.state_dict(),
'epoch': i,
'training_loss':np.mean(T_loss_L),
'validation_loss':np.mean(V_loss_L),
'training_acc':np.mean(T_acc_L),
'validation_acc':np.mean(V_acc_L)
}
'''
best_val=torch.load(config.model_save_path+"_best")["validation_loss"]
epoch=torch.load(config.model_save_path+"_best")["epoch"]
print("retraing from %d where val loss is %f"%(epoch,best_val))
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # module字段在最前面,从第7个字符开始就可以去掉module
new_state_dict[name] = v #新字典的key值对应的value一一对应
model.load_state_dict(new_state_dict)
print("reload model!")
有疑惑的话可以打印下state_dict
看看。