抛开pytorch,抛开一切,从现在开始,我们只关注ONNX
完全用 ONNX 的 Python API 构造简单的两个输入一个输出的 ONNX 模型。
1. 构建输入和输出的信息
helper.make_tensor_value_info
可以看到这个函数需要[name, elem_type, shape, doc_string, shape_denotation]
name :自定义名称
elem_type :定义类型,我们用FLOAT
shape :[600,200,8,1,64] 我们假模假样的给一个
后面两个参数先不管了,本文这三个参数就够用;
开始构建
input_0 = helper.make_tensor_value_info('input_0', TensorProto.FLOAT, [600,200,8,1,64])
input_1 = helper.make_tensor_value_info('input_1', TensorProto.FLOAT, [600,1,8,64,1])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [600,200,8,1,1])
2. 构建节点信息
helper.make_node
op_type : 算子类型 :这里有坑,但不影响本文的学习,所以先不管。
inputs : 列表,包含了所有输入的name,我们这里是两个
outputs :列表, 包含了所有输出的name,标准的ONNX都是一个输出
name : 自定义节点名称
够用了
matmul = helper.make_node('MatMul', ['input_0', 'input_1'], ['output'], 'MatMul_0')
3. 构建图信息
helper.make_graph
nodes : 节点信息
name : 自定义图名称
inputs : 输入的名称
output : 输出的名称
够了
graph = helper.make_graph([matmul], 'torch-jit-export', [input_0, input_1], [output])
4. 构建模型
model = helper.make_model(graph)
5. 保存导出
onnx.save(model, 'Gemv.onnx')
6. 查看效果
7. 代码
import onnx
from onnx import helper
from onnx import TensorProto
# 两个输入一个输出
input_0 = helper.make_tensor_value_info('input_0', TensorProto.FLOAT, [600,200,8,1,64])
input_1 = helper.make_tensor_value_info('input_1', TensorProto.FLOAT, [600,1,8,64,1])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [600,200,8,1,1])
# 构建节点信息
matmul = helper.make_node('MatMul', ['input_0', 'input_1'], ['output'], 'MatMul_0')
# 构建图信息
graph = helper.make_graph([matmul], 'torch-jit-export', [input_0, input_1], [output])
# 构建模型
model = helper.make_model(graph)
# onnx.checker.check_model(model)
print(model)
onnx.save(model, 'Gemv.onnx')
百练成砖,一起成长。