pytorch学习001- -如何保存模型

保存和加载模型

只保存模型的参数

保存

torch.save(model.state_dict(),'xxx.pth')

加载

model = net()	#首先要先定义网络模型
state_dict = torch.load('xxx.pth')	# 读取pth文件中的参数
model.load_state_dict(state_dict['model'])	#将参数导入模型

这种方法操作比较麻烦,但是比较节省内存。

official example

class MyModule(torch.nn.Module):


m = MyModule()
m.state_dict()

torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)

附加保存一些其他信息

torch.save({
             'epoch': epoch + opt.start,
            'model_state_dict': model.state_dict()
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': epoch_losses.avg},
        }, os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch + opt.start)))

保存整个模型

保存

torch.save(net, 'xxx.pt')

加载

test = torch.load('xxx.pt')	#注意其中pt文件的路径

这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。

官方文档

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>