TorchScript的TracedModule和ScriptModule的区别

本文深入解析TorchScript的TracedModule与ScriptModule的区别,通过对比官方教程与ReneWang文章,详细阐述了ScriptModule如何追踪到ifelse控制流,为读者提供了清晰的理解。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近在读TorchScript的入门介绍,看了官方链接的文章,然后感觉是云山雾罩,不知所云。

然后搜索到了Rene Wang的文章,才感觉明白了好多。

官方的介绍TracedModule的缺点例子是这样的:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)

输出是:

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(input, ), )
  _1 = torch.tanh(torch.add(_0, h, alpha=1))
  return (_1, _1)

然后官方再介绍ScriptMoudle:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

然后输出是:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

然后文章里就高潮叫hooray了,我还是一脸懵逼的,根本没有看到ScriptModule的code与TracedModule的code差异啊?

Rene Wang的文章解释的很到位,关键要看my_cell.dg.code,其实他们是这样的

traced_gate = torch.jit.trace(my_cell.dg, (x,))
print(traced_gate.code)

--输出--
c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  after removing the cwd from sys.path.

def forward(self,
    x: Tensor) -> Tensor:
  return x
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
print(traced_cell.code)
#只有从dg.code才能看到 if else 流程控制语句执行了
print(traced_cell.dg.code)

--输出--

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

RecursiveScriptModule(
  original_name=MyCell
  (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

这样能够清晰的看到ScriptModule追踪到了if else 控制流。

基于torch 1.4.0版本,可能官方的tutorial是基于老的版本的实例。

<think>好的,用户现在问的是TorchScript是什么,我需要先回想一下之前关于Stable Baselines3的讨论,确保这次回答保持一致的风格结构。用户之前可能对PyTorch相关的内容感兴趣,因为Stable Baselines3是基于PyTorch的。 首先,我应该明确TorchScript的定义。TorchScript是PyTorch的一个特性,允许将模型转换为可序列化优化的代码,可以在非Python环境中运行。用户可能想知道它的用途优势。 接下来,要分点说明核心功能。比如,独立于Python运行环境、性能优化、模型序列化、跨平台部署等。这些点能帮助用户理解为什么需要使用TorchScript。 然后,需要解释使用场景,比如生产部署、移动端应用、多线程高性能需求等。用户可能关心实际应用中的好处,比如脱离Python依赖,提升速度。 基本使用步骤也很重要。用户可能需要一个简单的指南,如何将模型转换成TorchScript,包括跟踪脚本化两种方法,以及如何保存加载模型。提供代码示例会更有帮助,但要注意用户的技术水平,保持简洁明了。 还要提到注意事项,比如动态控制流的限制,调试困难,输入类型约束等,避免用户在实际使用中遇到问题。 最后,对比TorchScript与普通PyTorch模型的区别,突出其优势,比如部署能力性能优化。这样用户能更清楚何时选择使用TorchScript。 需要确保回答结构清晰,分点明确,语言简洁,符合之前用户可见层的指导要求。同时,检查是否所有数学公式代码都符合格式要求,比如行内公式用$...$,独立公式用$$...$$,但在这个问题中可能涉及较少数学内容,主要关注代码示例。 可能需要回顾之前的回答结构,比如核心功能、使用步骤、注意事项、对比等,保持一致性。确保用中文,避免术语错误,比如正确翻译术语如“序列化”、“优化”等。 用户可能有后续问题,比如如何调试TorchScript模型,或具体部署步骤,但当前只需回答TorchScript是什么,不需要展开太多。保持回答全面但不过于冗长。</think>TorchScript 是 PyTorch 提供的一种将动态图模型(PyTorch 默认的即时执行模式)转换为**静态图模型**的工具,旨在实现模型的高效部署跨平台运行。通过 TorchScript,PyTorch 模型可以脱离 Python 环境,在 C++、移动端或其他高性能场景中执行。 --- ### 一、核心功能 1. **独立于 Python 运行环境** 将 PyTorch 模型转换为 TorchScript 后,可在不依赖 Python 解释器的环境中运行(如 C++ 服务端、移动设备)。 2. **性能优化** 静态图模型可通过编译器优化(如算子融合、内存预分配)提升推理速度,尤其适合生产环境的高吞吐需求。 3. **模型序列化** 支持将模型保存为 `.pt` 或 `.pth` 文件,便于存储跨平台共享。 4. **跨平台部署** 兼容多种硬件架构(CPU、GPU)操作系统(Linux、Windows、Android、iOS)。 --- ### 二、使用场景 - **生产环境部署**:脱离 Python 依赖,直接集成到 C++ 服务中。 - **移动端/嵌入式设备**:在资源受限的环境中运行深度学习模型。 - **高性能推理**:通过静态图优化减少运行时开销。 - **多线程支持**:在 C++ 中利用多线程加速模型推理。 --- ### 三、基本使用步骤 #### 1. **将 PyTorch 模型转换为 TorchScript** 两种主要方法: - **跟踪(Tracing)**:通过输入示例数据记录模型执行流程。 ```python import torch model = MyModel() # 自定义的 PyTorch 模型 example_input = torch.rand(1, 3, 224, 224) traced_model = torch.jit.trace(model, example_input) ``` 适用于无复杂控制流(如条件分支、循环)的模型。 - **脚本化(Scripting)**:直接解析模型代码生成静态图。 ```python scripted_model = torch.jit.script(model) ``` 适用于包含动态控制流的模型(需确保代码符合 TorchScript 语法)。 #### 2. **保存与加载模型** ```python # 保存模型 traced_model.save("model.pt") # 加载模型(在 Python 或 C++ 中) loaded_model = torch.jit.load("model.pt") ``` #### 3. **在 C++ 中调用模型** ```cpp #include <torch/script.h> torch::jit::script::Module model = torch::jit::load("model.pt"); std::vector<torch::jit::IValue> inputs{torch::ones({1, 3, 224, 224})}; at::Tensor output = model.forward(inputs).toTensor(); ``` --- ### 四、注意事项 1. **动态控制流的限制** 若使用 `torch.jit.trace`,模型中的条件语句或循环可能无法正确记录,需改用 `torch.jit.script`。 2. **调试困难** 静态图模型的错误信息可能不如动态图直观,需逐步验证转换结果。 3. **输入类型约束** TorchScript 对输入张量的类型形状有严格限制,需确保输入与训练时一致。 --- ### 五、TorchScript 与普通 PyTorch 模型的对比 | 特性 | 普通 PyTorch 模型 | TorchScript 模型 | |---------------------|---------------------------|--------------------------------| | **执行模式** | 动态图(即时执行) | 静态图(预编译优化) | | **依赖环境** | 必须依赖 Python | 可脱离 Python 运行 | | **性能** | 灵活性高,但优化空间有限 | 编译优化提升推理速度 | | **部署场景** | 适合研究与实验 | 适合生产环境、移动端嵌入式设备 | --- ### 六、典型应用示例 - 移动端图像分类(如通过 PyTorch Mobile 部署到 Android)。 - 工业级推理服务(如使用 `libtorch` C++ 库构建高性能 API)。 - 边缘计算设备上的实时决策(如自动驾驶中的低延迟推理)。 如果需要进一步了解如何调试 TorchScript 或优化转换流程,可以继续提问! 🚀
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值