文章目录
注:大多框架的模型(pytorch、caffe2、mxnet)在加载的时候(如果有大佬知道),都需要知道输入的shape,caffe2甚至需要输入的name(caffe2只是在转onnx时需要知道input_name和input_size,如果哪位大佬知道如何在caffe2模型中获取input name或Input size可以告诉我一下),tensorflow需要知道输出的name。。。。在这方面还是cntk和caffe好
pytorch
pytorch安装
Linux和Windows现都已支持Stable(1.0)版本
官网安装
linux cpu安装:
conda install pytorch-cpu torchvision-cpu -c pytorch
版本查询:
import torch
print(torch.__version__)
pytorch转onnx
pytorch在导入模型时,需要有定义模型类的文件(.py格式)
python pytorch2onnx.py ptmodel_path ptmodel_class_py_path insize_n insize_c insize_w insize_h saved_onnx_path_name
ptmodel_path
:模型路径,pth、pkl、pt格式不限,pytorch保存模型的时候选择的是保存整个模型而不是只保存网络训练参数。
ptmodel_class_py_path
:模型类定义文件的路径(可以不放在当前路径下),py文件,为pytorch定义模型的类的文件
insize_n、insize_c、insize_w、insize_h
:输入的维度
saved_onnx_path_name
:保存为onnx模型的路径+文件名(没有路径会保存在当前路径下)
关于pytorch模型的题外话
感觉pytorch保存模型实在是太麻烦了、就算把网络结构和训练参数都保存了,依然需要在导入的时候在某处存在这个模型类的定义。对pytorch的探索可能还是太浅,现在暂且只能做到把这个类定义文件拷贝到当前路径下,用完再删除。