Pytorch保存和加载模型(load和load_state_dict)

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。
首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
先举最简单的例子:

import torch

model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape

state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)

torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

pth文件既可能保存了模型的图结构,也有可能没保存;
加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

  • 实验

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'only_weights.pth')

model_state_dict = torch.load('only_weights.pth')
model1 = models.vgg16() # describe the graph shape
model1.load_state_dict(model_state_dict)
model1.eval()

torch.save(model1, 'whole_model.pth')

model2 = torch.load('whole_model.pth')
model2.eval()

# model3 = torch.load('only_weights.pth')
# model3.eval()    # Error


model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

  • torch.load_state_dict()函数的用法

在Pytorch中构建好一个模型后,一般需要进行预训练权重中加载。torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中,操作方式如下所示:

sd_net = torchvision.models.resnte50(pretrained=False)
sd_net.load_state_dict(torch.load('*.pth'), strict=True)

在本博文中重点关注的是 属性 strict; 当strict=True,要求预训练权重层数的键值与新构建的模型中的权重层数名称完全吻合;如果新构建的模型在层数上进行了部分微调,则上述代码就会报错:说key对应不上。

此时,如果我们采用strict=False 就能够完美的解决这个问题。也即,与训练权重中与新构建网络中匹配层的键值就进行使用,没有的就默认初始化

参考博文:
https://blog.csdn.net/ChaoMartin/article/details/118686268

https://blog.csdn.net/leviopku/article/details/123925804

  • 9
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值