pytorch如何保存和加载模型

本文介绍了PyTorch中如何从本地加载预训练模型,进行迁移学习时的加载策略,以及模型的保存和加载操作。通过示例代码展示了如何筛选不同结构的模型参数,以及如何在加载时固定部分参数。
摘要由CSDN通过智能技术生成

从本地地址直接加载预训练模型

在本地,有几种方式可以避免下载直接加载预训练模型:
https://www.cnblogs.com/ywheunji/p/10605614.html

  1. 直接修改源码,改为本地地址
  2. 把模型权重下载至torch的缓存文件夹

当需要迁移学习时加载预训练模型

此时,我的模型和预训练的模型有些结构可能不一致,需要多出一步,筛选掉预训练模型中与我的结构不一致的部分:
https://blog.csdn.net/VictoriaW/article/details/72821329
https://www.94e.cn/info/4270

vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)

torch.load 从文件中加载一个用torch.save()保存的对象。

当我是用自己从网上下载的模型的时候要用torch.load:

        pretrain = torch
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值