Pytorch保存模型

前提

模型训练好后自然想要将里面所有层涉及的权重保存下来,这样子我们的模型就能部署在任意有pytorch环境下了。

Torch.save/load

先介绍一下纯py接口的保存方式。

class my_model(nn.Modules):
	def __init__(self):
		super(my_module,self).__init__()
		self.relu = nn.Relu()
	def forward(x, self):
		return self.relu(x)
		......
		......
model = my_model()
torch.save(model.state_dict(),"./model_name.pth")
torch.load(model,"./model_name.pth)

这样就会在当前目录保存一份.pth文件了(里面只保存了这个模型的所有权重即.parameters())。
下面的load就是在其他脚本中使用这个模型预训练好的权重
Torch.save 官网有详细介绍

Torch.jit

这个方式的保存更加高级(保存为TorchScript)可以与Torch c++接口通用。
这样带来的好处就是保存下来的模型为编译过后的运行时不需要python解释器,运行速度会更快。并且这种方式可以连带模型的定义一起保存,无需import model。

一般有两种保存方式

  1. torch.jit.trace
    这种方式为追踪一个函数的执行流,使用时需要提供一个测试输入。
    官网有样例。
    Torch.jit.trace
    需要注意的是这个接口只追踪测试输入走过的函数执行流(如果模型中有多条分支的话只会保存测试输入走过的分支!!!!!),所以对于一些多分支的模型不要采用这种方式,采用下面的Torch.jit.script。比如model.eval()和model.train()可以控制模型内BN层和dropout的权重是否固定,如果采用这种方式只能保留其中之一状态(固定或不固定)。
  2. torch.jit.script
    使用这种方式可以将一个模型完整的保存下来,和上面的trace正好相对。如果模型中的分支很多,并且在运行时会改变的话一定要用这种形式保存。
    Torch.jit.trace
    这里简单写一下如何使用:(我这个测试模型里面只有一个简单的relu所以随便输入一个Tensor就行了)
store = torch.jit.trace(model,torch.randn(1,2,3,dtype=torch.float32))
store  = torch.jit.script(model)
torch.jit.save(store,"./model_name.pth")
torch.jit.load("./model_name.pth")
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值