保存和加载模型

这篇博客介绍了如何使用PyTorch的torch.save和torch.load函数保存和加载模型,包括模型状态字典的处理,以及在多卡训练后的模型加载。在推理阶段,模型需先调用model.eval()进入评估模式。示例中展示了如何处理预训练模型的权重加载,并提供了两个具体的保存和加载模型的例子。
摘要由CSDN通过智能技术生成

torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

例子1:
模型保存时,同时保存了训练的其他信息;

torch.save({
                        'model': model.state_dict(),
                        'classes': classes,
                        'args': args},
                        os.path.join(args.checkpoints, 'model_{}_{}.pth'.format(epoch, i)))

modelPath ='static/model_56_300.pth'
pretrained_model = models.resnext101_32x8d(pretrained=True)
IN_FEATURES = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(IN_FEATURES, 8)
model = pretrained_model

state_dict = torch.load(modelPath, map_location='cpu')

new_state_dict = OrderedDict()
####state_dict['model'] 获取模型权重
#### 多卡训练,移除名称前面的module.
for k, v in state_dict['model'].items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict,strict=True)
# model.load_state_dict(state_dict,strict=False)
# model =state_dict['model']
model = model.eval()
torch.save(model.state_dict(), PATH)

当保存好模型用来推断的时候,只需要保存模型学习到的参数,使用torch.save()函数来保存模型state_dict,它会给模型恢复提供 最大的灵活性,这就是为什么要推荐它来保存的原因。

在运行推理之前,务必调用model.eval()去设置 dropout 和 batch normalization 层为评估模式。

torch.save(model, PATH)

例子2:

torch.save(model.state_dict(), './models/model_train_NSFW_viskit_violence_20211117'+str(epoch)+'.pth')
modelPath ='static/model_train_NSFW_viskit_violence_20211117.pth'
pretrained_model = models.resnext101_32x8d(pretrained=True)
IN_FEATURES = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(IN_FEATURES, 8)
model = pretrained_model
state_dict = torch.load(modelPath, map_location='cpu')

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict,strict=True)
# model.load_state_dict(state_dict,strict=False)
# model =state_dict['model']
model = model.eval()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值