1. onnx简介
onnx(Open Neural Network Exchange),是一种用于表示神经网络的规范的模型,便于模型在不同框架下进行转换。
2. 可视化
可视化网页:https://netron.app/
结果:
3. pytorch框架下的转换
pytorch框架下,使用torch.onnx.export()
函数进行转换。
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None,
output_names=None,aten=False, export_raw_ir=False, operator_export_type=None,
opset_version=None, _retain_param_name=True,do_constant_folding=False,
example_outputs=None, strip_doc_string=True,
dynamic_axes=None, keep_initializers_as_inputs=None)
几个比较重要的参数声明:
torch.onnx.export(model, # 网络模型,在dqn中就是保存eval_net
torch.randn(1, 3, 224, 224), # 描述输入的维数,具体数值无关紧要
export_onnx_file, # 输出onnx的名称,也可以限定位置
input_names=["input"], # 输入节点的名称,可以不写,写就要对应上
output_names=["output"], # 输出节点的名称,可以不写,写就要对应上
)
4. pytorch下DQN网络转换样例与注意事项
dummy_input = torch.randn([1,N_STATES])
torch.onnx.export(dqn.eval_net, dummy_input, r".\models\dqn.onnx")
r".\models\dqn.onnx"
,通过相对地址的写法,限定了onnx生成的位置与名称
要注意,DQN输入虽然只有N_STATES维,但转换之后的onnx网络却是一个1*N_STATES维的输入,需要通过torch.randn([1,N_STATES])
转换成恰当的格式,如果不添加,则会生成如下网络: