科研小白进阶之路——torch实现模型保存和加载的几种方式

模型的保存:

方式一:保存整个模型

model = resnet()
#-------省略训练过程------------
torch.save(model,path)

方式二:只保存模型参数

model = resnet()
#-------省略训练过程---------------
torch.save(model.state_dict,path)

模型的加载:
情况一:训练已经结束,测试时加载训练好的模型
1、在训练的py文件下面直接对模型进行测试:

model = resnet()     # 定义模型
#----训练过程省略-----
torch.save(model,path) #保存模型
#----训练好的模型进行测试----
model.eval
#---加载测试集等步骤

2、单独写一个infer.py
①加载整个模型和参数,对应模型的保存方式一,即当时就保存了整个模型,此时直接加载,然后去测试;

model_path = './model_000009.pth'
model = torch.load(model_path)   # 加载了整个模型

②只加载模型参数,对应模型保存的方式二,当时只保存了模型参数,此时需要先初始化模型,然后让模型加载训练好的模型的参数;

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

情况二:训练一半停了,想继续训练
只加载模型参数,网络结构从代码中创建

net = resnet()  #代码中创建网络结构
path = './pth'
saved_state_dict = torch.load(path) #加载模型参数
net.load_state_dict(saved_state_dict) #应用到网络中

情况三:加载预训练模型
1、使用torchvision自带的预训练模型

#----------resnet系列------------
import torchvision.models as models
model = models.ResNet(pretrained=True)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
#---------VGG系列-------------
model = models.VGG(pretrained=True)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
model = models.vgg16_bn(pretrained=True)
#---------alexnet-----------
Alexnet = models.alexnet()
#--------squeezenet--------
squeezenet = models.squeezenet1_0()

2、如果只需要加载预训练模型的部分参数&#

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值