torch.save: saves a serialized object to disk. This function uses Python's pickle module for serialization and can save models, tensors, dictionaries, and other objects.
torch.load: uses pickle's unpickling facility to deserialize a pickled object file into memory. It can also remap the data onto a given device (via map_location).
torch.nn.Module.load_state_dict: loads a model's parameter dictionary from a deserialized state_dict.
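A minimal sketch of the three APIs, using a small linear layer and a temporary file path (the path here is illustrative, not from the examples below):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), 'model.pth')

# torch.save: pickle the parameter dict to disk
torch.save(model.state_dict(), path)

# torch.load: unpickle it back into memory, remapping onto the CPU
state_dict = torch.load(path, map_location='cpu')

# load_state_dict: copy the deserialized parameters into a fresh model
model2 = nn.Linear(4, 2)
model2.load_state_dict(state_dict)
```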
Example 1:
When saving the model, other training information is saved along with it:
torch.save({
    'model': model.state_dict(),
    'classes': classes,
    'args': args},
    os.path.join(args.checkpoints, 'model_{}_{}.pth'.format(epoch, i)))
from collections import OrderedDict

import torch
import torch.nn as nn
from torchvision import models

modelPath = 'static/model_56_300.pth'
pretrained_model = models.resnext101_32x8d(pretrained=True)
IN_FEATURES = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(IN_FEATURES, 8)
model = pretrained_model
state_dict = torch.load(modelPath, map_location='cpu')
new_state_dict = OrderedDict()
# state_dict['model'] holds the model weights saved in the checkpoint above
# The model was trained on multiple GPUs, so remove the leading `module.` from each key
for k, v in state_dict['model'].items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=True)
# model.load_state_dict(state_dict, strict=False)
# model = state_dict['model']
model = model.eval()
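The `k[7:]` slice assumes every key carries the `module.` prefix that `nn.DataParallel` adds; if the checkpoint was saved without the wrapper, it would silently corrupt the keys. A sketch that strips the prefix only when it is present (the prefixed state_dict is simulated here so the example runs without a checkpoint file):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Simulate the `module.` prefix that nn.DataParallel adds to every key
prefixed = OrderedDict(('module.' + k, v) for k, v in model.state_dict().items())

# Strip the prefix only when it is actually there
new_state_dict = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in prefixed.items()
)

clean = nn.Linear(4, 2)
clean.load_state_dict(new_state_dict, strict=True)
```

Saving `model.module.state_dict()` at training time avoids the prefix entirely, so no stripping is needed at load time.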
torch.save(model.state_dict(), PATH)
When saving a model for inference, only the learned parameters need to be saved. Saving the model's state_dict with torch.save() gives the most flexibility when restoring the model later, which is why it is the recommended approach.
Before running inference, always call model.eval() to put the dropout and batch normalization layers into evaluation mode.
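The effect of model.eval() is easy to see with a tiny dropout network (a toy module, not the model above):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()      # training mode: dropout randomly zeros half the activations
train_out = net(x)

net.eval()       # evaluation mode: dropout becomes a no-op
y1 = net(x)
y2 = net(x)      # repeated eval-mode passes give identical outputs
```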
torch.save(model, PATH)  # alternatively, save the entire model (architecture + parameters)
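torch.save(model, PATH) pickles the whole module object, not just its parameters, so loading it back requires the model's class definition to be importable. A small sketch of the two approaches side by side (paths are illustrative):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
d = tempfile.mkdtemp()

# Whole-model save: pickles the module object itself
whole_path = os.path.join(d, 'whole.pth')
torch.save(model, whole_path)
# On recent PyTorch (>= 2.6) torch.load defaults to weights_only=True,
# which blocks unpickling arbitrary modules; weights_only=False restores
# the old behavior (the argument exists since PyTorch 1.13)
loaded = torch.load(whole_path, weights_only=False)

# state_dict save: parameters only; the architecture is rebuilt by hand
sd_path = os.path.join(d, 'sd.pth')
torch.save(model.state_dict(), sd_path)
rebuilt = nn.Linear(4, 2)
rebuilt.load_state_dict(torch.load(sd_path))
```

The state_dict route is more robust across code refactors, since the pickled whole model breaks if the class moves or changes.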
Example 2:
Here only the model's state_dict is saved, so the loaded object is the parameter dictionary itself:
torch.save(model.state_dict(), './models/model_train_NSFW_viskit_violence_20211117' + str(epoch) + '.pth')
modelPath = 'static/model_train_NSFW_viskit_violence_20211117.pth'
pretrained_model = models.resnext101_32x8d(pretrained=True)
IN_FEATURES = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(IN_FEATURES, 8)
model = pretrained_model
state_dict = torch.load(modelPath, map_location='cpu')
new_state_dict = OrderedDict()
# The checkpoint is the state_dict itself, so iterate over it directly
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=True)
# model.load_state_dict(state_dict, strict=False)
model = model.eval()