深度学习模型预测及模型的保留

今天这篇文章讨论一下我们进行深度学习时,如何将预测的结果如何转换为对应的标签以及如何将最后的模型进行保存和加载。
预测
为了更清晰地说明整个过程,我们还是以代码来说明一下:

from PIL import Iamge
labels=['cat','fish']
img=Image.open(FileName)
img=transforms(img)
img=img.unsequeeze(0)

prediction=simplenet(img)
prediction=prediction.argmax()
print(labels[prediction])

得到预测的结果很简单,只需要把我们的批次(batch)传入模型。然后要找出有较大概率的类。在这里,可以简单地将张量转换为一个数组,并比较这两个元素,不过通常会有更多元素。PyTorch提供了argmax()函数,这很有用,它会返回张量中最大值得索引。然后使用这个索引访问我们的标签,打印出预测结果。

模型保存
如果你对一个模型的性能很满意,或者由于某个原因需要停止训练,可以使用torch.save()方法采用Python的pickle格式保存模型的当前状态。反过来,我们也可以使用torch.load()方法加载之前保存的一个模型迭代。
所以,保存当前参数和模型结构的代码,如下所示:

torch.save('simplenet','./filedir')

可以使用如下代码进行加载代码:

simplenet=torch.load('./modeldir')

这会把参数以及模型的结构都保存到一个文件中。如果以后某个时间改变了模型的结构,可能就会有问题。由于这个原因,更常见的做法时保存模型的state_dict。这是一个标准的Python dict,其中包含模型中每一层参数的映射。可以保存如下state_dict:

torch.save(model.state_dict(),PATH)

恢复时,首先创建模型的一个实例,再使用load_state_dict。

simplenet=SimpleNet()
simplenet_state_dict=torch.load("./modeldir")
simplenet.load_state_dict(simplenet_state_dict)

这样的好处是,如果以某种方式扩展了模型,可以向load_state_dict提供了一个strict=False参数,为state_dict中确实有的模型层指定相应的参数,而如果所加载的state_dict与模型当前结构相比缺少或增加了某些层,也不会失败。因为这只是一个普通的Python dict 。可以改变键名来适应你的模型,如果要从一个完全不同的模型抽取参数,这会很方便。
:文章摘选自《基于PyTroch的深度学习》 Ian Pointer著

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
深度学习模型预测代码可以使用Python及相关的库来实现。在深度学习模型预测,常用的库包括keras、scikit-learn、pandas和tensorflow。这些库提供了丰富的功能和工具,可以帮助我们构建和训练深度神经网络模型。 在使用深度学习模型进行预测之前,我们需要先准备好数据集,并对数据进行预处理。然后,我们可以使用keras库来构建深度神经网络模型。根据不同的任务和数据类型,可以选择不同的模型结构,如LSTM、GRU、CNN、LSTM-CNN、BiLSTM、Self-Attention、LSTM-Attention、Transformer等。 在构建模型后,我们可以使用优化算法来更新模型的参数值,以使任务的指标表现变好。常用的优化算法包括梯度下降法和随机梯度下降法。通过迭代训练模型,我们可以得到一个“好”的模型。 最后,我们可以使用训练好的模型对新的数据进行预测。通过调用模型预测函数,我们可以得到预测结果。 具体的深度学习模型预测代码可以根据具体的任务和数据集进行编写。可以参考相关的教程和文档,以及使用示例代码来帮助实现深度学习模型预测功能。 #### 引用[.reference_title] - *1* *3* [一文深度学习建模预测全流程(Python)](https://blog.csdn.net/qq_40877422/article/details/121301741)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [【深度学习时间序列预测案例】零基础入门经典深度学习时间序列预测项目实战(附代码+数据集+原理介绍)](https://blog.csdn.net/m0_47256162/article/details/128585814)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

毛毛真nice

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值