【pytorch】将训练好的模型部署至生产环境:onnx及onnxruntime使用

68 篇文章 2 订阅
65 篇文章 3 订阅

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。各类框架中的模型,通过ONNX进行转化见下图所示:
在这里插入图片描述
导出函数及参数介绍:

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False)
模型(torch.nn.Module) - 要导出的模型。
args(参数元组) - 模型的输入,例如,这-model(*args)是模型的有效调用。任何非变量参数将被硬编码到导出的模型中; 任何变量参数都将成为输出模型的输入,按照它们在参数中出现的顺序。如果args是一个变量,这相当于用该变量的一个元组来调用它。(注意:将关键字参数传递给模型目前还不支持,如果需要,给我们留言。)

f - 类文件对象(必须实现返回文件描述符的fileno)或包含文件名的字符串。一个二进制Protobuf将被写入这个文件。
export_params(布尔,默认为True- 如果指定,所有参数将被导出。如果要导出未经训练的模型,请将其设置为False。在这种情况下,导出的模型将首先将其所有参数作为参数,按照指定的顺序model.state_dict().values()

verbose(布尔,默认为False- 如果指定,我们将打印出一个调试描述的导出轨迹。

training(布尔,默认为False- 在训练模式下导出模型。目前,ONNX只是为了推导出口模型,所以你通常不需要将其设置为True

深度学习模型实际上就是一个计算图。模型部署时通常把模型转换成静态的计算图,即没有控制流(分支语句、循环语句)的计算图。
注意,torch.onnx.export中需要的模型实际上是一个torch.jit.ScriptModule。而要把普通 PyTorch 模型转一个这样的 TorchScript 模型,有跟踪(trace)和记录(script)两种导出计算图的方法。如果给torch.onnx.export传入了一个普通 PyTorch 模型(torch.nn.Module),那么这个模型会默认使用跟踪的方法导出。(记录法需使用model_script = torch.jit.script(model)将模型预先做简化 )

跟踪法和记录法的区别在于:
跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);记录法则能通过解析模型来正确记录所有的控制流。

此外,对于网络输入输出可能出现的动态情况,应设置动态维度:
其中,将输入输出的某些维度设置为动态维度(默认全是静态)。

import torch 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv = torch.nn.Conv2d(3, 3, 3) 
 
    def forward(self, x): 
        x = self.conv(x) 
        return x 
 
 
model = Model() 
dummy_input = torch.rand(1, 3, 10, 10) 
model_names = ['model_static.onnx',  
'model_dynamic_0.onnx',  
'model_dynamic_23.onnx'] 
 
dynamic_axes_0 = { 
    'in' : [0], 
    'out' : [0] 
} 
dynamic_axes_23 = { 
    'in' : [2, 3], 
    'out' : [2, 3] 
} 
 
torch.onnx.export(model, dummy_input, model_names[0],  
input_names=['in'], output_names=['out']) 
torch.onnx.export(model, dummy_input, model_names[1],  
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0) 
torch.onnx.export(model, dummy_input, model_names[2],  
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23) 

使模型在 ONNX 转换时有不同的行为:
torch.onnx.is_in_onnx_export()标识符仅在模型导出时为真(如将后处理步骤也引入torch.nn.Module模型中),如:

import torch 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv = torch.nn.Conv2d(3, 3, 3) 
 
    def forward(self, x): 
        x = self.conv(x) 
        if torch.onnx.is_in_onnx_export(): 
        #数值限制在[0, 1]之间
            x = torch.clip(x, 0, 1) 
        return x 

下面的代码可验证生成的onnx格式是否正确:

import onnx 
 
onnx_model = onnx.load("srcnn.onnx") 
try: 
    onnx.checker.check_model(onnx_model) 
except Exception: 
    print("Model incorrect") 
else: 
    print("Model correct")

ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器,即”推理引擎“。 ONNX Runtime 可以直接读取并运行 .onnx 文件。
以下是python接口示例代码:

#model代码略
x = torch.randn(1, 3, 256, 256) 
 
with torch.no_grad(): 
    torch.onnx.export( 
        model, 
        x, 
        "srcnn.onnx", 
        opset_version=11, 
        input_names=['input'], 
        output_names=['output'])
        
import onnxruntime 
 #onnxruntime.InferenceSession用于获取一个 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession("srcnn.onnx") 
ort_inputs = {'input': input_img} 
#其第一个参数为输出张量名的列表,第二个参数为输入值的字典
ort_output = ort_session.run(['output'], ort_inputs)[0] 

ort_output = np.squeeze(ort_output, 0) 
ort_output = np.clip(ort_output, 0, 255) 
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8) 
cv2.imwrite("face_ort.png", ort_output)

注意,转换为onnx时,输入参数必须为tensor,输入参数及模型必须做相应处理:

#具体类节选:
    def forward(self, x, upscale_factor): 
    #使用.item(),将tensor变量转为普通python变量
        x = interpolate(x, 
                        scale_factor=upscale_factor.item(), 
                        mode='bicubic', 
                        align_corners=False) 
 
... 
 
#使用时
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy() 
 
... 
 
with torch.no_grad(): 
#将常数3改为torch.tensor(3)
    torch.onnx.export(model, (x, torch.tensor(3)), 
                      "srcnn2.onnx", 
                      opset_version=11, 
                      input_names=['input', 'factor'], 
                      output_names=['output']) 

但是注意,这样做,使upscale_factor.item(),在追踪时,无法追到普通python变量,使得第二个参数,仍被以常量3固定下来。
使用自定义算子:替换interpolate

class NewInterpolate(torch.autograd.Function): 
    @staticmethod 
    def symbolic(g, input, scales): 
        return g.op("Resize", 
                    input, 
                    g.op("Constant", 
                         value_t=torch.tensor([], dtype=torch.float32)), 
                    scales, 
                    coordinate_transformation_mode_s="pytorch_half_pixel", 
                    cubic_coeff_a_f=-0.75, 
                    mode_s='cubic', 
                    nearest_mode_s="floor") 
 
    @staticmethod 
    def forward(ctx, input, scales): 
        scales = scales.tolist()[-2:] 
        return interpolate(input, 
                           scale_factor=scales, 
                           mode='bicubic', 
                           align_corners=False) 
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值