- PyTorch 模型转 ONNX 模型
import torch
class ConNet(torch.nn.Module):
    """Small CNN classifier used as the ONNX export example.

    Each conv stage keeps the spatial size (3x3 kernel, stride 1, padding 1)
    and then halves it with a 2x2 pooling layer, so a 28x28 input becomes
    32 feature maps of 7x7. The first Linear layer is sized for exactly that
    (32 * 7 * 7), so inputs must have shape (N, 3, 28, 28).

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # (N, 3, 28, 28) -> (N, 16, 14, 14)
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2, 2),
        )
        # (N, 16, 14, 14) -> (N, 32, 7, 7)
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(32 * 7 * 7, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Linear(64, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for x of shape (N, 3, 28, 28)."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten all but the batch dimension; x.shape[0] is traced as a
        # dynamic value, so the exported graph is not tied to one batch size.
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return self.out(x)
# Export ConNet to ONNX by tracing it with a dummy input.
convnet = ConNet()
convnet.eval()  # export in inference mode
x = torch.randn(1, 3, 28, 28)  # dummy input used only to trace the graph
export_onnx_file = "../convnet.onnx"  # output path, relative to the working directory
torch.onnx.export(
    convnet,
    x,
    export_onnx_file,
    # ONNX opset to target. The default differs between torch releases,
    # so pin it explicitly for reproducible exports.
    opset_version=10,
    do_constant_folding=True,  # pre-compute constant subexpressions at export time
    input_names=["input"],  # name of the graph input
    output_names=["output"],  # name of the graph output
    # Mark dim 0 as symbolic so the exported model accepts any batch size,
    # not only the batch of 1 used for tracing.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
运行后会在当前工作目录的上一级目录生成 `convnet.onnx` 文件。
- ONNX模型推理
import torch
import numpy as np
import onnxruntime as rt
import time
def onnx_runtime(model_path="../convnet.onnx"):
    """Run one forward pass of the exported ONNX model on random input.

    Loads ``model_path`` with onnxruntime and feeds it a random float32
    tensor of shape (1, 3, 28, 28). Prints the first output row, then
    prints the model-load time and the inference time separately.

    Args:
        model_path: Path to the ONNX file. The default is the path written
            by the export script.

    Returns:
        The first row of the model's first output.
    """
    imgdata = np.random.randn(1, 3, 28, 28).astype(np.float32)

    load_start = time.perf_counter()
    # Since onnxruntime 1.9, GPU-enabled builds raise unless `providers` is
    # passed explicitly. Use the CPU provider; swap in e.g.
    # "CUDAExecutionProvider" to run on a GPU build.
    sess = rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    load_time = time.perf_counter() - load_start

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    # Time only sess.run. Session creation (loading and optimizing the graph)
    # is a one-off cost and would otherwise dominate the measured latency.
    infer_start = time.perf_counter()
    pred_onnx = sess.run([output_name], {input_name: imgdata})
    infer_time = time.perf_counter() - infer_start

    pred_onnx = pred_onnx[0][0]  # first output tensor, first batch row
    print("outputs:")
    print(pred_onnx)
    print("模型加载用时", load_time)
    print("推理用时", infer_time)
    return pred_onnx


onnx_runtime()
- ONNX模型可视化
用在线可视化工具打开生成的 `convnet.onnx` 文件即可查看网络结构，例如：
https://www.machunjie.com/dl/Visualization/index.html