TorchScript学习笔记
TorchScript是一种可从python代码中创建序列化模型的方法。可以从python代码中保存,并在非python环境中加载模型。注:TorchScript 主要实现的是在 PyTorch 中表示神经网络模型所需的 Python 功能,并不适用于所有的Python特性。
torch.jit
TorchScript是Pytorch的JIT实现。
JIT ,全称是 Just In Time Compilation(即时编译)。
JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而可以使用C++把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等。(多线程执行和性能原因,一般Python代码并不适合做部署。)
导出模型主要有两种方式,Scripting和Tracing。
1.Tracing方式
tracing是相对简单的方式,输入向量,追踪向量在forward函数的流动来获得模型结构。
必须要有输入。
只适用于比较简单的模型,如果forward函数中有控制流结构,向量一次无法遍历所有的分支,这时就要借助script方式。
torch.jit.trace(func, example_inputs)
torch.jit.trace会将torch::jit::Module 转成 torch::jit::Graph 。
如果trace的是Python 函数,那么返回ScriptFunction , 如果是nn.Module.forward或者nn.Module,返回的就是ScriptModule。 如果trace的时候是eval/train模式,那么返回的ScriptModule就是eval/train模式。
- 例子:trace一个函数
def sigmoid(z):
s = 1 / (1 + 1 / np.exp(z))
return s
x = torch.full((2, 2), 1)
print("x: ",x)
traced_func = torch.jit.trace(sigmoid, x)
print(traced_func.graph)
print(traced_func(x))
输出:
x: tensor([[1, 1],
[1, 1]])
graph(%0 : Long(2:2, 2:1, requires_grad=0, device=cpu)):
%1 : Double(2:2, 2:1, requires_grad=0, device=cpu) = prim::Constant[value= 2.7183 2.7183 2.7183 2.7183 [ CPUDoubleType{
2,2} ]]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%2 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%1) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%3 : Long(requires_grad=0, device=cpu) = prim::Constant[value={
1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%4 : Double(2:2, 2:1, requires_grad