Pytorch之保存和加载预训练的模型

在深度学习中会用到迁移学习的方法,也就是我们把在其它数据集上训练比较好的model拿到我们的模型上来进行finetune,这样避免了我们重新去花费时间去训练模型,比如vgg16提取图像特征的这个模型,大大节省了我们训练的时间。这个过程我们就涉及到加载预训练的模型,有的时候我们需要加载整个模型,有时候我们需要模型的一个部分,因此在本文中将会对在Pytroch这个框架中如何加载预训练的模型做以阐述。说到加载预训练好的模型那就不得不说我们如何保存训练好的模型,训练一个model,将这个训练的结果保存起来,然后再测试集上进行测试,这个也是本文将要说的地方。因此就分为两个部分:模型的保存和模型的加载。

1、模型的保存

1)仅保存参数

torch.save(model.state_dict(),path)
#model 模型
#path 保存的路径,保存的文件都是以.pkl或者.pth的形式保存的

2)  既保存模型结构也保存参数。

torch.save(model,path)
#model 模型
#path 保存的路径,保存的文件都是以.pkl或者.pth的形式保存的

 

2、模型的加载

1) 加载仅保存参数的模型

model = Model()
model.load_state_dict(torch.load(path))

#测试
res = model(test_data)

2) 加载模型结构和参数信息

model = torch.load(path)

#测试
res = model(test_data)

 

3、加载部分预训练模型的参数

pretrained_model = vgg16(pretrained=Treu)
model = Model()

#将模型的layer和参数形成键值对,k是该层的名称,v是参数
pre_dict = pretrained_model.state_dict()
model_dict = model.state_dict()

#仅将用的的k键值对筛选出来
pre_dic = {k:v for k,v in pre_dict.items() if k in model_dict}

#更新model_dict里面的参数
model_dict.update(pre_dict)

#加载更新之后的参数
model.load_state_dict(model_dict)

参考资料:

https://blog.csdn.net/weixin_41278720/article/details/80759933

总结:在模型的保存中大多数还是仅保存模型的参数,而不保存模型的参数和模型,因为这样会占用的空间比较大。

 

 

 

 

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值