【官方教程】TorchScript简介(introduction to torchscript)

目录

1.PyTorch模型基础

2.TorchScript的基础

2.1 跟踪(Tracing)模块

3.使用脚本来转换模块

3.1 混合脚本(Scripting)和跟踪(Tracing)

4.保存和加载模型

进一步阅读


本文翻译自官方文档:https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html。本文的输出结果为我的实测结果,可能与官方手册不同。

本教程是对TorchScript的简介,TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以在高性能环境(例如C ++)中运行。

在本教程中,我们将介绍:

1、PyTorch中的模型编写基础,包括:

  • Modules(模块)
  • 定义forward函数
  • 将模块组成模块的层次结构

2、将PyTorch模块转换为TorchScript(我们的高性能部署runtime)的特定方法

  • 跟踪现有模块
  • 使用脚本直接编译模块
  • 如何组合这两种方法
  • 保存和加载TorchScript模块

我们希望在完成本教程之后,您将继续阅读后续教程,该教程将引导您真正地从C ++调用TorchScript模型的示例。

import torch  # 这是同时使用PyTorch和TorchScript的全部所需!
print(torch.__version__)

输出结果

1.5.1

1.PyTorch模型基础

让我们开始定义一个简单的模块。Module(模块)是PyTorch中组成的基本单位。它包含:

  • 构造函数,为调用准备模块
  • 一组Parameters 和sub-Modules(参数和子模块。这些由构造函数初始化,并且可以在调用期间由模块使用。
  • forward 函数。这是调用模块时运行的代码。
    我们来看一个小例子:
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h) #tanh()为双曲正切,定义域:R,值域:(-1,1)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

输出结果:

(tensor([[0.8546, 0.7623, 0.2731, 0.8782],
        [0.7811, 0.7509, 0.8009, 0.4137],
        [0.5279, 0.1417, 0.7708, 0.6857]]), tensor([[0.8546, 0.7623, 0.2731, 0.8782],
        [0.7811, 0.7509, 0.8009, 0.4137],
        [0.5279, 0.1417, 0.7708, 0.6857]]))

因此,我们已经:

  1. 创建了一个torch.nn.Module的子类。

  2. 定义一个构造函数。构造函数没有做太多事情,只是将构造函数称为super。

  3. 定义了一个forward 函数,该函数需要两个输入并返回两个输出。forward 函数的实际内容并不是很重要,但是它是一种伪造的RNN单元——即,该函数应用于循环。

我们实例化了该模块,并制作了xy,它们只是3x4的随机值矩阵。 然后,我们使用my_cell(x,h)调用该cell。 这又调用了我们的forward 函数。

让我们做一些更有趣的事情:

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell)
print(my_cell(x, h))

输出结果:

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.7895, 0.3672, 0.7823, 0.6751],
        [0.3382, 0.5957, 0.7047, 0.7534],
        [0.7783, 0.5605, 0.6458, 0.3532]], grad_fn=<TanhBackward>), tensor([[0.7895, 0.3672, 0.7823, 0.6751],
        [0.3382, 0.5957, 0.7047, 0.7534],
        [0.7783, 0.5605, 0.6458, 0.3532]], grad_fn=<TanhBackward>))

我们已经重新定义了模块MyCell,但是这次我们添加了self.linear属性,并在forward函数中调用了self.linear

这里到底发生了什么? torch.nn.Linear是PyTorch标准库中的Module 。就像MyCell一样,可以使用调用语法来调用它。我们正在建立模块的层次结构。

在模块上打印可以直观地表示该模块的子类层次结构。在我们的示例中,我们可以看到我们的线性子类及其参数。

通过以这种方式组合模块,我们可以简洁而易读地编写具有可重用组件的模型。

您可能已经在输出中注意到grad_fn。这是PyTorch的自动微分方法(称为autograd)的详细信息。简而言之,该系统允许我们通过潜在的复杂程序来计算导数。该设计为模型创作提供了极大的灵活性。

现在,让我们检查一下它的灵活性:

import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值