pytorch保存模型的两种方法


前言

模型的本质是一堆用某种结构存储起来的参数
用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要用到的时候加载一下直接使用。
保存的时候有两种方式:
一种方式是直接将整个模型保存下来,之后直接加载整个模型,但这样会比较耗内存,但内存吗嘛,不是什么大问题,我遇到的模型一般不超过100M。这都是很大的了;
另一种是只保存模型的参数,之后用到的时候再创建一个同样结构的新模型,然后把所保存的参数导入新模型。(也可以,也挺方便的)


一、保存整个模型

#保存
torch.save(the_model, PATH)
#读取
model = torch.load(PATH)

读取时不需要先定义model,比如:model=resnet50()。直接加载赋值就行。


二、只保存参数

保存参数:

torch.save(model.state_dict(),path)

读取模型:

# 测试所保存的模型
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)

1、加载参数
2、实例化模型
3、将参数赋予模型
也可以在定义模型后直接

new_m.load_state_dict(torch.load('rnn.pt'))

模型不同后缀名的区别

经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已)。
在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。


总结

持续更新

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Brandy_Whisky

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值