目录
3.1 混合脚本(Scripting)和跟踪(Tracing)
本文翻译自官方文档:https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html。本文的输出结果为我的实测结果,可能与官方手册不同。
本教程是对TorchScript的简介,TorchScript是PyTorch模型(nn.Module
的子类)的中间表示,可以在高性能环境(例如C ++)中运行。
在本教程中,我们将介绍:
1、PyTorch中的模型编写基础,包括:
- Modules(模块)
- 定义forward函数
- 将模块组成模块的层次结构
2、将PyTorch模块转换为TorchScript(我们的高性能部署runtime)的特定方法
- 跟踪现有模块
- 使用脚本直接编译模块
- 如何组合这两种方法
- 保存和加载TorchScript模块
我们希望在完成本教程之后,您将继续阅读后续教程,该教程将引导您真正地从C ++调用TorchScript模型的示例。
import torch # 这是同时使用PyTorch和TorchScript的全部所需!
print(torch.__version__)
输出结果
1.5.1
1.PyTorch模型基础
让我们开始定义一个简单的模块。Module
(模块)是PyTorch中组成的基本单位。它包含:
- 构造函数,为调用准备模块
- 一组
Parameters
和sub-Modules(
参数和子模块)
。这些由构造函数初始化,并且可以在调用期间由模块使用。 forward
函数。这是调用模块时运行的代码。
我们来看一个小例子:
import torch
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
def forward(self, x, h):
new_h = torch.tanh(x + h) #tanh()为双曲正切,定义域:R,值域:(-1,1)
return new_h, new_h
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
输出结果:
(tensor([[0.8546, 0.7623, 0.2731, 0.8782],
[0.7811, 0.7509, 0.8009, 0.4137],
[0.5279, 0.1417, 0.7708, 0.6857]]), tensor([[0.8546, 0.7623, 0.2731, 0.8782],
[0.7811, 0.7509, 0.8009, 0.4137],
[0.5279, 0.1417, 0.7708, 0.6857]]))
因此,我们已经:
-
创建了一个
torch.nn.Module
的子类。 -
定义一个构造函数。构造函数没有做太多事情,只是将构造函数称为super。
-
定义了一个
forward
函数,该函数需要两个输入并返回两个输出。forward
函数的实际内容并不是很重要,但是它是一种伪造的RNN单元——即,该函数应用于循环。
我们实例化了该模块,并制作了x
和y
,它们只是3x4的随机值矩阵。 然后,我们使用my_cell(x,h)
调用该cell。 这又调用了我们的forward
函数。
让我们做一些更有趣的事情:
import torch
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell)
print(my_cell(x, h))
输出结果:
MyCell(
(linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.7895, 0.3672, 0.7823, 0.6751],
[0.3382, 0.5957, 0.7047, 0.7534],
[0.7783, 0.5605, 0.6458, 0.3532]], grad_fn=<TanhBackward>), tensor([[0.7895, 0.3672, 0.7823, 0.6751],
[0.3382, 0.5957, 0.7047, 0.7534],
[0.7783, 0.5605, 0.6458, 0.3532]], grad_fn=<TanhBackward>))
我们已经重新定义了模块MyCell
,但是这次我们添加了self.linear
属性,并在forward函数中调用了self.linear
。
这里到底发生了什么? torch.nn.Linear
是PyTorch标准库中的Module
。就像MyCell
一样,可以使用调用语法来调用它。我们正在建立模块的层次结构。
在模块上打印可以直观地表示该模块的子类层次结构。在我们的示例中,我们可以看到我们的线性子类及其参数。
通过以这种方式组合模块,我们可以简洁而易读地编写具有可重用组件的模型。
您可能已经在输出中注意到grad_fn
。这是PyTorch的自动微分方法(称为autograd)的详细信息。简而言之,该系统允许我们通过潜在的复杂程序来计算导数。该设计为模型创作提供了极大的灵活性。
现在,让我们检查一下它的灵活性:
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()