PyTorch模型迁移至MXNet实战:无缝对接与工程应用

有关将 PyTorch 转换为 ONNX,然后加载到 MXNet 的教程

转载自:
https://docs.aws.amazon.com/zh_cn/dlami/latest/devguide/tutorial-onnx-pytorch-mxnet.html

关于ONNX 概述

开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 受到 Amazon Web Services、Microsoft、Facebook 和其他多个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。

将 PyTorch 模型转换为 ONNX,然后将模型加载到 MXNet 中

  1. 使用文本编辑器创建一个新文件,并在脚本中使用以下程序来训练 PyTorch 中的模拟模型,然后将它导出为 ONNX 格式。
# Build a Mock Model in PyTorch with a convolution and a reduceMean layer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.onnx as torch_onnx

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False)

    def forward(self, inputs):
        x = self.conv(inputs)
        #x = x.view(x.size()[0], x.size()[1], -1)
        return torch.mean(x, dim=2)

# Use this an input trace to serialize the model
input_shape = (3, 100, 100)
model_onnx_path = "torch_model.onnx"
model = Model()
model.train(False)

# Export the model to an ONNX file
dummy_input = Variable(torch.randn(1, *input_shape))
output = torch_onnx.export(model, 
                          dummy_input, 
                          model_onnx_path, 
                          verbose=False)
print("Export of torch_model.onnx complete!")

  1. 使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 MXNet 中打开 ONNX 格式文件。
import onnx
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
import numpy as np

# Import the ONNX model into MXNet's symbolic interface
sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx")
print("Loaded torch_model.onnx!")
print(sym.get_internals())

  1. 运行此脚本后,MXNet 将拥有加载的模型,并打印一些基本模型信息。

函数介绍

将一个模型导出到ONNX格式。该exporter会运行一次你的模型,以便于记录模型的执行轨迹,并将其导出

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, dynamic_axes = {})

参数
  • model
    (torch.nn.Module)-要被导出的模型
  • args
    (参数的集合)-模型的输入,例如,这种model方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。因为export运行模型,所以我们需要提供一个输入张量x。只要是正确的类型和大小,其中的值就可以是随机的。请注意,除非指定为动态轴,否则输入尺寸将在导出的ONNX图形中固定为所有输入尺寸。在此示例中,我们使用输入batch_size 1导出模型,但随后dynamic_axes 在torch.onnx.export()。因此,导出的模型将接受大小为[batch_size,3、100、100]的输入,其中batch_size可以是可变的。
  • f
    -一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
  • export_params
    (bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
  • verbose
    (bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
  • training
    (bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
  • input_names
    (list of strings, default empty list)-按顺序分配名称到图中的输入节点。
  • output_names
    (list of strings, default empty list)-按顺序分配名称到图中的输出节点。
  • dynamic_axes
    {‘input’ : {0 : ‘batch_size’}, ‘output’ : {0 : ‘batch_size’}}) # variable lenght axes

函数介绍

将onnx模型文件导入mxnet

mxnet.contrib.onnx.import_model(model_file)

参量
  • model_file
    (str)– ONNX模型文件名
输出
  • sym
    (Symbol)– MXNet符号对象
  • arg_params
    (dict of str to NDArray)–以mxnet.ndarray.NDArray格式存储的转换参数的字典
  • aux_params
    (dict of str to NDArray)–以mxnet.ndarray.NDArray格式存储的转换参数的字典
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值