目录
1. 什么是ONNX格式?为什么要使用ONNX格式的模型?
Open Neural Network Exchange(ONNX)是一种开放的深度学习模型交换格式。ONNX格式可以实现不同框架(比如Pytorch和Tensorflow)之间的模型互操作和跨平台部署。使用ONNX格式可以还可以提高模型的推理效率,并有广泛的社区支持。
2. Pytorch模型导出ONNX格式的基本方法
2.1. 使用Pytorch构建一个简单的CNN
首先,用Pytorch定义一个简单的 CNN 模型,该模型包含两个卷积层、两个池化层和一个全连接层。将文件命名为SimpleCNN.py,代码如下:
import torch
import torch.nn as nn
# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) # 输入1通道(灰度图),输出32通道
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # 输入32通道,输出64通道
self.fc1 = nn.Linear(64 * 7 * 7, 128) # 64个7x7特征图展平后传入全连接层
self.fc2 = nn.Linear(128, 10) # 最终输出10个类别
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.max_pool2d(x, 2) # 池化层,2x2
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(x.size(0), -1) # 展平操作
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
这是非常基本的Pytorch编程,这里不再赘述。不明白的小伙伴可以参考Pytorch的入门教程,网络上有很多。
2.2. 导出ONNX格式的CNN
Pytorch模型导出ONNX格式有很多细节与注意事项,比如是否有多个输入,是否有动态维度等等。这里暂时只考虑有单个输入和输出的静态模型,更复杂的情况在之后的文章中讨论。注意这里需要建立一个伪输入张量(dummy_input)来模拟模型的输入。
import torch.onnx
from SimpleCNN import SimpleCNN
# 创建CNN模型
model = SimpleCNN()
# 将模型设置为评估模式(至关重要)
model.eval()
# 创建一个输入张量(假设输入为 1 个 28x28 的灰度图像)
dummy_input = torch.randn(1, 1, 28, 28) # batch_size=1, 1通道,28x28的图像
# 导出为 ONNX 格式
onnx_file_path = "simple_cnn.onnx"
torch.onnx.export(model, # 要导出的模型
dummy_input, # 模型的输入
onnx_file_path, # 导出文件的路径
export_params=True, # 是否导出模型参数
opset_version=12, # ONNX 的操作集版本
input_names=["input"], # 输入层名称
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, # 支持动态 batch_size
"output": {0: "batch_size"}}) # 输出层名称
如代码中所示,torch.onnx.export() 方法有很多参数。这里我们需要注意的是 opset_version 和 dynamic_axes 参数。前者决定了 ONNX 操作集的版本,最新版为13;后者决定了导出的 ONNX 模型使用的是静态的还是动态的输入与输出。当然,我们可以把所有的输入和输出都设置为动态的,这样在使用起来很方便。但是这样会影响 ONNX 模型的推理速度,所以还是需要根据实际的情况谨慎选择。
2.3. 验证导出的ONNX模型
在模型导出后,可以使用 onnx 库(注意不是torch.onnx)加载和验证 ONNX 格式的模型,代码如下:
import onnx
# 加载 ONNX 模型
onnx_model = onnx.load('./simple_cnn.onnx')
# 验证模型
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")
onnx.checker.check_model 检查导出的 ONNX 模型的一致性,即模型在结构、格式和配置方面的正确性和完整性。
3. 使用Netron可视化导出的ONNX模型
ONNX格式并不像Pytorch源代码一样可以直接解读,需要额外的工具来进行模型可视化。Netron是一款非常好用的模型可视化工具,支持ONNX格式。Netron工具可以在线运行https://netron.app/。打开主页,直接加载导出的ONNX文件,ONNX模型的结构便会显示出来。
我们可以看到,Netron显示的ONNX模型结构和Pytorch源代码构造的模型结构完全一样。