8145v5 参数_一文梳理pytorch保存和重载模型参数攻略

87222b6027c0dfb7426081cfd0289f3d.png

训练过程中保存模型参数,就不怕断电了——沃资基·索德

在训练完成之前,我们需要每隔一段时间保存模型当前参数值,一方面可以防止断电重跑,另一方面可以观察不同迭代次数模型的表现;在训练完成以后,我们需要保存模型参数值用于后续的测试过程。所以,保存的对象包含网络参数值、优化器参数值、epoch值等等。

一、定义一个容易识别的网络

在正式介绍模型的保存和加载之前,我们首先定义一个基本的网络Net,它只包含一个全连接层:

class 

将全连接的权重w和偏差b分别设置为10和1,全连接的计算方式如下:

602d9cdab580d22767e33aba846ff64d.png

假设输入x=1,可以知道y值为11:

95d32924c3fb154488393c14d6fe4dee.png

测试一下输出是不是11,代码如下:

x 

输出:tensor([[11.]], grad_fn=<AddmmBackward>),说明上述计算是正确的。不采用参数随机初始化,而是用特殊的数值初始化,是因为我们希望重载模型的时候,能够从特殊数值一眼判断出保存和重载过程是否正确,也可以把权重设置为一张图片数值,然后判断加载的参数值能不能恢复原图。

二、保存Net的参数值

保存模型参数之前,需要知道Net的参数值存储在其state_dict(状态字典)属性中,我们查看一下net的state_dict包含哪些参数:

print

我们将会得到net包含的所有参数名称与参数值

09e4049c8b0a489684abcd922757f29b.png

包含一个weight和一个bias,对应的值分别是10和1,和我们之前定义的全连接层一致。我们需要保存的就是这个state_dict,保存的函数为“torch.save()”,参数是我们需要保存的dict和存储路径

torch.save(obj=net.state_dict(), f="models/net.pth")

现在,同级目录models下将会出现net.pth文件,pth文件中的内容就是net的参数名称和值对应的state_dict,如下:

26a81c87fb697382e49ebb35505f6d53.png

三、加载Net参数值并用于新的模型

最后一个步骤就是从pth文件中重新获取Net参数值,并把参数值装载到新定义的Model对象中。这里我们重新定义一个结构和Net类相同的类Model,区别仅仅是Model参数初始值和Net不同,代码如下:

class 

这里将Model的初始值权重w和偏差都设置为0,查看其state_dict:

model 

得到的w和b值与预期相同,均为0,如下:

f22480f2560641e61d45b5ada5d4155a.png

现在,我们将model对象的参数值设置为net.pth中的值,需要使用“model.load_state_dict()”函数重置model的参数值为"torch.load(models/ net.pth)"中的参数值,如下:

model

至此,model的w和b值就不再是0了,而是net中w和b对应的10和1,如下:

09e4049c8b0a489684abcd922757f29b.png

其中参数值重载的核心函数为“model.load_state_dict()”,每个继承自nn.Module的网络都能通过这个函数设定参数值。

四、优化器与epoch的保存

保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先“torch.save()”再“torch.load_state_dict()”,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:

net = Net()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch = 96

现在,创建一个字典来保存所有的对象,并用save函数保存这个字典

all_states 

所有的对象都被保存到models文件夹下了:

417357da9da8aec7b13b94d93100eea6.png

可以使用load()函数把所有的对象再次提取出来:

reload_states 

得到的所有参数如下:

3141624210a68ac529b58061aeff805a.png

五、总结

pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。

参考:

https://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

https://blog.csdn.net/Code_Mart/article/details/88254444

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值