import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
import onnx
import onnxruntime
from torch.nn.functional import interpolate
# 模型导出方法:跟踪法和记录法
# 跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);
# 记录法则能通过解析模型来正确记录所有的控制流
class Model(torch.nn.Module):
def __init__(self, n):
super().__init__()
self.n = n
self.conv = torch.nn.Conv2d(3, 3, 3)
# 带循环模型
def forward(self, x):
for i in range(self.n):
x = self.conv(x)
return x
def test_script_and_trace():
models = [Model(2), Model(3)]
model_names = ['model_2', 'model_3']
for model, model_name in zip(models, model_names):
dummy_input = torch.rand(1, 3, 10, 10)
dummy_output = model(dummy_input)
model_trace = torch.jit.trace(model, dummy_input)
model_script = torch.jit.script(model)
# 跟踪法与直接 torch.onnx.export(model, ...)等价
torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output)
# 记录法必须先调用 torch.jit.sciprt
torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)
if __name__ == '__main__':
test_script_and_trace()
模型部署二 、模型导出方法
最新推荐文章于 2023-08-30 14:18:49 发布