pytorch保存模型的两种方法


前言

模型的本质是一堆用某种结构存储起来的参数
用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要用到的时候加载一下直接使用。
保存的时候有两种方式:
一种方式是直接将整个模型保存下来,之后直接加载整个模型,但这样会比较耗内存,但内存吗嘛,不是什么大问题,我遇到的模型一般不超过100M。这都是很大的了;
另一种是只保存模型的参数,之后用到的时候再创建一个同样结构的新模型,然后把所保存的参数导入新模型。(也可以,也挺方便的)


一、保存整个模型

#保存
torch.save(the_model, PATH)
#读取
model = torch.load(PATH)

读取时不需要先定义model,比如:model=resnet50()。直接加载赋值就行。


二、只保存参数

保存参数:

torch.save(model.state_dict(),path)

读取模型:

# 测试所保存的模型
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)

1、加载参数
2、实例化模型
3、将参数赋予模型
也可以在定义模型后直接

new_m.load_state_dict(torch.load('rnn.pt'))

模型不同后缀名的区别

经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已)。
在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。


总结

持续更新

  • 10
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
PyTorch 中,保存模型两种主要的方式:保存整个模型保存模型的参数。下面分别介绍这两种方式的实现方法。 1. 保存整个模型 保存整个模型时,需要将模型的结构和参数都保存下来。这可以通过使用 PyTorch 的 `torch.save()` 函数来实现。 ```python # 定义模型 model = MyModel() # 保存整个模型 torch.save(model, 'model.pth') ``` 上面代码中,`MyModel()` 是自定义的模型类,`model.pth` 是保存整个模型的文件名。 加载整个模型时,可以使用 PyTorch 的 `torch.load()` 函数来加载模型。 ```python # 加载整个模型 model = torch.load('model.pth') ``` 注意,加载模型时需要保证相应的模型代码已经被定义。 2. 保存模型的参数 保存模型的参数时,只需要将模型的参数保存下来,而不需要保存模型的结构。这可以通过使用 PyTorch 的 `state_dict()` 函数来实现。 ```python # 定义模型 model = MyModel() # 保存模型的参数 torch.save(model.state_dict(), 'model_params.pth') ``` 上面代码中,`model.state_dict()` 返回模型的参数字典,`model_params.pth` 是保存模型参数的文件名。 加载模型参数时,需要先定义模型并加载相应的参数。 ```python # 定义模型 model = MyModel() # 加载模型的参数 model.load_state_dict(torch.load('model_params.pth')) ``` 注意,加载模型参数时需要保证相应的模型代码已经被定义,并且模型的结构要与保存参数时的模型结构相同。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Brandy_Whisky

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值