pytorch script traced model 以及 模型保存 详解

保存eager model(模型和模型参数) 或者模型参数

import torch
# 加载模型
model = MyModule()
# 保存模型参数
torch.save(model.state_dict(), './model/Test/model.pth')

# 加载模型参数
state_dict = torch.load('./model/Test/model.pth')
model = model.load_state_dict(state_dict, strict=False)

import torch
# 保存模型
torch.save(model, './model/Test/model.pth')
# 加载模型
model = torch.load('./model/Test/model.pth')

torch.jit.save 和 torch.jit.load

torch.jit.save 用于保存使用 torch.jit.script 或 torch.jit.trace 转换后的模块对象。

import torch
torch.jit.save(traced_model, './model/Test/traced_model.pth')

torch.jit.load 用于加载使用 torch.jit.script 或 torch.jit.trace 转换后的模块对象。

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)

torch.jit.script 与torch.jit.trace

torch.jit.script 和torch.jit.trace是PyTorch 中用于将模型转换为脚本或跟踪模型执行的工具。它们是 PyTorch 的即时编译(Just-in-Time Compilation)模块的一部分,用于提高模型的执行效率并支持模型的部署。

torch.jit.trace(首选该导出方式)

跟踪模型可以看作是一个具有相同功能的脚本模型,但它还保留了原始模型的动态特性,可以使用更多高级特性,如动态图和控制流。(由于 torch.jit.trace 方法只跟踪了给定输入张量的执行路径,因此在使用转换后的模块对象进行推理时,输入张量的维度和数据类型必须与跟踪时使用的相同。)

  1. torch.jit.trace是一种模型导出方法;该方式可以“跟踪/记录”所有执行到图形中的操作。
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流(或者只有静态控制流),选择torch.jit.trace
  3. 支持python的预处理和动态行为;
  4. torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
  5. 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;
import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, return_b=False):
        super().__init__()
        self.return_b = return_b

    def forward(self, x):
        a = x + 2
        if self.return_b:  #属于静态控制
            b = x + 3
            return a, b
        return a

model = MyModule(return_b=True)

# Will work  成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))
# Will fail 失败
scripted = torch.jit.script(model)

torch.jit.script(必要时,或想在C++上用torch模型时(因为Torch脚本程序可以在其他语言的程序中运行))

  1. torch.jit.script是一种模型导出方法;其是编译python的模型源码得到可执行的图,即将模型转换为脚本 。
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.script。
  3. 不支持python的预处理和动态行为;
  4. 必须做下类型标注。
  5. torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。
import torch


# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x


model = MyModel()
print(model)

# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))

# 调用
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))

# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')

# 加载模型
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)

load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)

结果输出:

MyModel(
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值