TorchScript学习笔记

TorchScript学习笔记

TorchScript是一种可从python代码中创建序列化模型的方法。可以从python代码中保存,并在非python环境中加载模型。注:TorchScript 主要实现的是在 PyTorch 中表示神经网络模型所需的 Python 功能,并不适用于所有的Python特性。

torch.jit

TorchScript是Pytorch的JIT实现。
JIT ,全称是 Just In Time Compilation(即时编译)。
JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而可以使用C++把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等。(多线程执行和性能原因,一般Python代码并不适合做部署。)
导出模型主要有两种方式,Scripting和Tracing。

1.Tracing方式

tracing是相对简单的方式,输入向量,追踪向量在forward函数的流动来获得模型结构。
必须要有输入。
只适用于比较简单的模型,如果forward函数中有控制流结构,向量一次无法遍历所有的分支,这时就要借助script方式。

torch.jit.trace(func, example_inputs)

torch.jit.trace会将torch::jit::Module 转成 torch::jit::Graph 。
如果trace的是Python 函数,那么返回ScriptFunction , 如果是nn.Module.forward或者nn.Module,返回的就是ScriptModule。 如果trace的时候是eval/train模式,那么返回的ScriptModule就是eval/train模式。

  • 例子:trace一个函数
def sigmoid(z):
    s = 1 / (1 + 1 / np.exp(z))
    return s


x = torch.full((2, 2), 1)
print("x: ",x)
traced_func = torch.jit.trace(sigmoid, x)
print(traced_func.graph)
print(traced_func(x))

输出:

x:  tensor([[1, 1],
        [1, 1]])
graph(%0 : Long(2:2, 2:1, requires_grad=0, device=cpu)):
  %1 : Double(2:2, 2:1, requires_grad=0, device=cpu) = prim::Constant[value= 2.7183  2.7183  2.7183  2.7183 [ CPUDoubleType{
   2,2} ]]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %2 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%1) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %3 : Long(requires_grad=0, device=cpu) = prim::Constant[value={
   1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %4 : Double(2:2, 2:1, requires_grad
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值