在成功安装TVM环境后,我们尝试编译运行一个简单的手写数字识别模型。模型使用tvm自带的tests/micro/testdata/mnist/mnist-8.onnx。测试图片是从https://raw.githubusercontent.com/junehui/ImageProcessing/master/MNIST_data/MNIST_data.zip
下载的。编写模型编译运行代码如下:
import onnx
import numpy as np
import tvm
from tvm import te
import tvm.relay as relay
from tvm.contrib import graph_executor
from PIL import Image
onnx_model = onnx.load('mnist-8.onnx')
image_path = '0.png'
img = Image.open(image_path).resize((28, 28))
img = img.convert('L')
img = np.array(img, dtype=np.float32)
data = img.reshape((1, 1, 28, 28))
target = "llvm"
input_name = "Input3"
shape_dict = {input_name: data.shape}
# 导入onnx模型
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
# 设置优化级别
with tvm.transform.PassContext(opt_level=3):
#编译模型
lib = relay.build(mod, target, params=params)
# 创建运行模块,指定在cpu0上运行模型
module = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
# 设置运行模块输入
module.set_input(input_name, data)
# 运行编译好的模型
module.run()
# 获取输出
out = module.get_output(0).numpy()
print(out)
输出
[[ 5648.1436 -2949.1753 649.1323 -3821.0854 -1681.0052 -169.5099
1740.3281 -1816.3844 -1328.9271 -646.9918]]
代码中主要有三部分:
1. from_onnx导入onnx模型,转换为tvm relay ir形式;
2. 配置模型优化pass环境;
3. 编译模型;
4. 使用图执行器运行编译好的模型。