Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。各类框架中的模型,通过ONNX进行转化见下图所示:
导出函数及参数介绍:
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False)
模型(torch.nn.Module) - 要导出的模型。
args(参数元组) - 模型的输入,例如,这-model(*args)是模型的有效调用。任何非变量参数将被硬编码到导出的模型中; 任何变量参数都将成为输出模型的输入,按照它们在参数中出现的顺序。如果args是一个变量,这相当于用该变量的一个元组来调用它。(注意:将关键字参数传递给模型目前还不支持,如果需要,给我们留言。)
f - 类文件对象(必须实现返回文件描述符的fileno)或包含文件名的字符串。一个二进制Protobuf将被写入这个文件。
export_params(布尔,默认为True) - 如果指定,所有参数将被导出。如果要导出未经训练的模型,请将其设置为False。在这种情况下,导出的模型将首先将其所有参数作为参数,按照指定的顺序model.state_dict().values()
verbose(布尔,默认为False) - 如果指定,我们将打印出一个调试描述的导出轨迹。
training(布尔,默认为False) - 在训练模式下导出模型。目前,ONNX只是为了推导出口模型,所以你通常不需要将其设置为True。
深度学习模型实际上就是一个计算图。模型部署时通常把模型转换成静态的计算图,即没有控制流(分支语句、循环语句)的计算图。
注意,torch.onnx.export中需要的模型实际上是一个torch.jit.ScriptModule。而要把普通 PyTorch 模型转一个这样的 TorchScript 模型,有跟踪(trace)和记录(script)两种导出计算图的方法。如果给torch.onnx.export传入了一个普通 PyTorch 模型(torch.nn.Module),那么这个模型会默认使用跟踪的方法导出。(记录法需使用model_script = torch.jit.script(model)将模型预先做简化 )
跟踪法和记录法的区别在于:
跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);记录法则能通过解析模型来正确记录所有的控制流。
此外,对于网络输入输出可能出现的动态情况,应设置动态维度:
其中,将输入输出的某些维度设置为动态维度(默认全是静态)。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
x = self.conv(x)
return x
model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
model_names = ['model_static.onnx',
'model_dynamic_0.onnx',
'model_dynamic_23.onnx']
dynamic_axes_0 = {
'in' : [0],
'out' : [0]
}
dynamic_axes_23 = {
'in' : [2, 3],
'out' : [2, 3]
}
torch.onnx.export(model, dummy_input, model_names[0],
input_names=['in'], output_names=['out'])
torch.onnx.export(model, dummy_input, model_names[1],
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0)
torch.onnx.export(model, dummy_input, model_names[2],
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23)
使模型在 ONNX 转换时有不同的行为:
torch.onnx.is_in_onnx_export()标识符仅在模型导出时为真(如将后处理步骤也引入torch.nn.Module模型中),如:
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
x = self.conv(x)
if torch.onnx.is_in_onnx_export():
#数值限制在[0, 1]之间
x = torch.clip(x, 0, 1)
return x
下面的代码可验证生成的onnx格式是否正确:
import onnx
onnx_model = onnx.load("srcnn.onnx")
try:
onnx.checker.check_model(onnx_model)
except Exception:
print("Model incorrect")
else:
print("Model correct")
ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器,即”推理引擎“。 ONNX Runtime 可以直接读取并运行 .onnx 文件。
以下是python接口示例代码:
#model代码略
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
torch.onnx.export(
model,
x,
"srcnn.onnx",
opset_version=11,
input_names=['input'],
output_names=['output'])
import onnxruntime
#onnxruntime.InferenceSession用于获取一个 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession("srcnn.onnx")
ort_inputs = {'input': input_img}
#其第一个参数为输出张量名的列表,第二个参数为输入值的字典
ort_output = ort_session.run(['output'], ort_inputs)[0]
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
注意,转换为onnx时,输入参数必须为tensor,输入参数及模型必须做相应处理:
#具体类节选:
def forward(self, x, upscale_factor):
#使用.item(),将tensor变量转为普通python变量
x = interpolate(x,
scale_factor=upscale_factor.item(),
mode='bicubic',
align_corners=False)
...
#使用时
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy()
...
with torch.no_grad():
#将常数3改为torch.tensor(3)
torch.onnx.export(model, (x, torch.tensor(3)),
"srcnn2.onnx",
opset_version=11,
input_names=['input', 'factor'],
output_names=['output'])
但是注意,这样做,使upscale_factor.item(),在追踪时,无法追到普通python变量,使得第二个参数,仍被以常量3固定下来。
使用自定义算子:替换interpolate
class NewInterpolate(torch.autograd.Function):
@staticmethod
def symbolic(g, input, scales):
return g.op("Resize",
input,
g.op("Constant",
value_t=torch.tensor([], dtype=torch.float32)),
scales,
coordinate_transformation_mode_s="pytorch_half_pixel",
cubic_coeff_a_f=-0.75,
mode_s='cubic',
nearest_mode_s="floor")
@staticmethod
def forward(ctx, input, scales):
scales = scales.tolist()[-2:]
return interpolate(input,
scale_factor=scales,
mode='bicubic',
align_corners=False)