PyTorch 模型的保存和加载

1. Pytorch 加载完整模型的参数

  1. 保存加载整个模型
# 保存整个模型
torch.save (model_object, 'model.pk1')
# 加载整个模型 
model = torch.load('model.pkl')
  1. 保存模型的参数 (推荐使用)
# 模型参数保存
torch.save (model_object.state_dict(), 'params.pk1') # 保存的是以字典 key - value pair 形式的数据,每一个参数对应着一个值 state_dict 状态字典 

# 模型参数加载
model_object.load_state_dict(torch.load('params.pkl'))

2.加载部分模型参数

大多数时候我们需要根据我们的任务调节我们的模型,很难保证说我们的模型和我们的任务完全一致,但是预训练模型确实有利于提高任务的准确率,为了结合两者的优点,我们可以使用加载部分预训练模型的方法。


# 0. 先加载进全部的参数存进 pretrained_dict 中
pretrained_dict = t.load ('params.pk1')

#1. load state dict 加载状态字典,就是加载属于这个 model 的状态字典,看看这个 model 里面有哪些变量的 name
model_dict = model.state_dict()

# 2. filter out unnecessary keys 过滤掉不属于这个模型的 key 值,只有当 key 的值在这个模型里面的时候才加载这个 key 其他的 key 值自动放弃使用  相当于两个模型求了一个交集。
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 3. overwrite entries in the existing state dict 更新 model 的数据字典 (更新成为可以加载的数据字典)
model_dict.update(pretrained_dict)
# 4. load the new state dict ()  加载参数,只加载模型中存在的参数,放弃一些不存在的参数,只加载修改过的参数。
model.load_state_dict(model_dict)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值