需要安装MindSpore工具包进行下面操作。MindSpore安装地址:https://www.mindspore.cn/install/
在使用MindSpore进行模型训练过程后,如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。
MINDIR:MindSpore框架的一种基于图表示的函数式IR,其最核心的目的是服务于自动微分变换,目前可用于MindSpore Lite端侧推理。
AIR:全称Ascend Intermediate Representation,类似ONNX,是华为定义的针对机器学习所设计的开放式的文件格式,能更好地适配Ascend AI处理器。
ONNX:全称Open Neural Network Exchange,是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。
以下通过实例来介绍保存CheckPoint格式文件和导出MINDIR、AIR和ONNX格式文件的方法。
-
导出MINDIR格式文件方法
保存了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。当前支持基于静态图,且不包含控制流语义的推理网络导出。导出该格式文件的代码样例如下:
from mindspore.train.serialization import export
import numpy as np
resnet = ResNet50()
# return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
# load the parameter into net
load_param_into_net(resnet, param_dict)
input = np.random.uniform(0.0, 1.0, size=[32, 3, 224, 224]).astype(np.float32)
export(resnet, Tensor(input), file_name='resnet50-2_32.mindir', file_format='MINDIR')
建议使用.mindir作为MINDIR格式文件的后缀名。
-
导出AIR格式文件方法
如果想继续在昇腾AI处理器上做推理,需要通过网络和CheckPoint生成对应的AIR格式模型文件。导出该格式文件的代码样例如下:
from mindspore.train.serialization import export
import numpy as np
resnet = ResNet50()
# return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
# load the parameter into net
load_param_into_net(resnet, param_dict)
input = np.random.uniform(0.0, 1.0, size=[32, 3, 224, 224]).astype(np.float32)
export(resnet, Tensor(input), file_name='resnet50-2_32.air', file_format='AIR')
使用export接口之前,需要先导入mindspore.train.serialization。
input用来指定导出模型的输入shape以及数据类型。
建议使用.air作为AIR格式文件的后缀名。
-
导出ONNX格式文件方法
如果想继续在昇腾AI处理器、GPU、CPU等多种硬件上做推理,需要通过网络和CheckPoint生成对应的ONNX格式模型文件。导出该格式文件的代码样例如下:
from mindspore.train.serialization import export
import numpy as np
resnet = ResNet50()
# return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
# load the parameter into net
load_param_into_net(resnet, param_dict)
input = np.random.uniform(0.0, 1.0, size=[32, 3, 224, 224]).astype(np.float32)
export(resnet, Tensor(input), file_name='resnet50-2_32.onnx', file_format='ONNX')
建议使用.onnx作为ONNX格式文件的后缀名。