pt转onnx
Support Matrix
torch.onnx
onnx上查看转换方法
开始转换
- 加载pt文件
torchvision支持的model 地址
# Instantiate the torchvision architecture, then load the trained weights.
model = torchvision.models.xxx()  # replace xxx with the real architecture, e.g. resnet50
model.load_state_dict(torch.load("xxx.pth"))
# Switch to inference mode (freezes dropout / batch-norm behavior) before export.
model.eval()
torchvision 不支持的自定义网络,以 resnet50_nfc 为例
# Custom (non-torchvision) network, using resnet50_nfc as the example.
model = get_model("resnet50", num_label, use_id=False, num_id=num_id)  # user-defined network
model.load_state_dict(torch.load("xxx.pth"))
model.eval()  # inference mode before export
- onnx导出
## Prepare for export: a dummy input that fixes the graph's input shape.
## NOTE(review): batch size 3 looks arbitrary; torch.onnx.export bakes this shape
## into the graph unless dynamic_axes is passed — confirm whether a fixed batch is intended.
dummy_input = torch.randn(3, 3, 224, 224, device='cpu')
## Names for the graph's inputs and outputs.
## NOTE(review): names beyond the model's real inputs ("learned_%d") are applied to
## weight initializers, not to inputs — here just one extra name, "learned_0".
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(1)]
output_names = ["output1"]
## Export the model to rr.onnx (verbose=True prints the traced graph).
torch.onnx.export(model, dummy_input, "rr.onnx", verbose=True, input_names=input_names, output_names=output_names)
- 验证
torch.Tensor 和 numpy.ndarray 互转(torch.rand 对应的是 numpy.random.rand)
# A torch.Tensor and a numpy.ndarray convert back and forth freely.
x = torch.rand(2, 2)
x1 = x.numpy()             # Tensor -> ndarray (CPU tensors share the underlying memory)
x2 = torch.from_numpy(x1)  # ndarray -> Tensor (also zero-copy)
开始验证
import onnxruntime as ort
import numpy as np
## Load the exported ONNX file into an inference session.
ort_session = ort.InferenceSession('rr.onnx')
# outputs = ort_session.run(None, {'actual_input_1': np.random.randn(3, 3, 224, 224).astype(np.float32)})
## dummy_input is a torch tensor; ONNX Runtime expects numpy arrays, so convert it.
outputs = ort_session.run(None, {'actual_input_1': dummy_input.numpy()})
## Run the same input through PyTorch for comparison.
with torch.no_grad():  # forward-only check; also lets us call .numpy() on the output
    torch_out = model(dummy_input)  # call model(...), not model.forward(...), so hooks run
print(torch_out)
print(outputs[0])
## Numerical agreement check (small float drift between backends is expected).
np.testing.assert_allclose(torch_out.numpy(), outputs[0], rtol=1e-3, atol=1e-5)
运行结果