保存eager model(模型和模型参数) 或者模型参数
import torch
# 加载模型
model = MyModule()
# 保存模型参数
torch.save(model.state_dict(), './model/Test/model.pth')
# 加载模型参数
state_dict = torch.load('./model/Test/model.pth')
model = model.load_state_dict(state_dict, strict=False)
import torch
# 保存模型
torch.save(model, './model/Test/model.pth')
# 加载模型
model = torch.load('./model/Test/model.pth')
torch.jit.save 和 torch.jit.load
torch.jit.save 用于保存使用 torch.jit.script 或 torch.jit.trace 转换后的模块对象。
import torch
torch.jit.save(traced_model, './model/Test/traced_model.pth')
torch.jit.load 用于加载使用 torch.jit.script 或 torch.jit.trace 转换后的模块对象。
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)
torch.jit.script 与torch.jit.trace
torch.jit.script 和torch.jit.trace是PyTorch 中用于将模型转换为脚本或跟踪模型执行的工具。它们是 PyTorch 的即时编译(Just-in-Time Compilation)模块的一部分,用于提高模型的执行效率并支持模型的部署。
torch.jit.trace(首选该导出方式)
跟踪模型可以看作是一个具有相同功能的脚本模型,但它还保留了原始模型的动态特性,可以使用更多高级特性,如动态图和控制流。(由于 torch.jit.trace 方法只跟踪了给定输入张量的执行路径,因此在使用转换后的模块对象进行推理时,输入张量的维度和数据类型必须与跟踪时使用的相同。)
- torch.jit.trace是一种模型导出方法;该方式可以“跟踪/记录”所有执行到图形中的操作。
- 在模型内部的数据类型只有张量,且没有for if while等控制流(或者只有静态控制流),选择torch.jit.trace
- 支持python的预处理和动态行为;
- torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
- 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;
import torch
from torch import nn
class MyModule(nn.Module):
def __init__(self, return_b=False):
super().__init__()
self.return_b = return_b
def forward(self, x):
a = x + 2
if self.return_b: #属于静态控制
b = x + 3
return a, b
return a
model = MyModule(return_b=True)
# Will work 成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))
# Will fail 失败
scripted = torch.jit.script(model)
torch.jit.script(必要时,或想在C++上用torch模型时(因为Torch脚本程序可以在其他语言的程序中运行))
- torch.jit.script是一种模型导出方法;其是编译python的模型源码得到可执行的图,即将模型转换为脚本 。
- 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.script。
- 不支持python的预处理和动态行为;
- 必须做下类型标注。
- torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.fc = torch.nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.conv(x)
x = torch.nn.functional.relu(x)
x = x.view(-1, 64 * 8 * 8)
x = self.fc(x)
return x
model = MyModel()
print(model)
# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 调用
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')
# 加载模型
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)
load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)
结果输出:
MyModel(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
original_name=MyModel
(conv): RecursiveScriptModule(original_name=Conv2d)
(fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
original_name=MyModel
(conv): RecursiveScriptModule(original_name=Conv2d)
(fc): RecursiveScriptModule(original_name=Linear)
)