目前遇到一个问题:有一个ONNX格式的网络模型,可以通过Netron进行可视化,Netron地址为 https://netron.app
模型的可视化结构如下:
模型一共有三层:两层卷积运算和一层全连接,每个卷积运算的卷积核、输入尺寸和输出尺寸都可以在图中看到。但是这个网络有一个问题:在全连接之前缺少一个Reshape算子。需要先将1*14*8*8的特征图展平成长度为896(=14*8*8)的一维向量,再将该向量送入全连接层,才能得到形状为【1,10】的输出。
现在有任务是将ONNX网络模型转换为PyTorch网络模型,共计有两步:
1. 根据可视化得到的网络结构,使用PyTorch框架重建网络模型(只搭建结构)
class Model_mnist(torch.nn.Module):
    """PyTorch rebuild of the ONNX MNIST network: two convolutions and one linear layer.

    For a 1x28x28 input the feature maps are 10x26x26 after ``conv1`` and
    14x8x8 after ``conv2``; the latter is flattened to 896 values, which is
    the input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        self.conv2 = torch.nn.Conv2d(10, 14, kernel_size=(5, 5), stride=(3, 3), padding=(0, 0))
        self.fc1 = torch.nn.Linear(896, 10, bias=True)

    def forward(self, x):
        """Run the network on ``x`` and return the ``(1, 10)`` class scores.

        Each layer's output is floor-divided by 1024.
        NOTE(review): presumably a fixed-point rescale that mirrors the
        original ONNX graph -- confirm against the source model.
        """
        out = self.conv1(x) // 1024
        out = F.relu(out)
        out = self.conv2(out) // 1024
        out = F.relu(out)
        # This is the Reshape step missing from the original ONNX graph.
        # The leading 1 fixes the batch size: only one sample per call works.
        out = out.reshape(1, -1)
        out = self.fc1(out) // 1024
        return out
2. 读取ONNX模型的权重及Bias,权重和bias都存储在model_onnx.graph.initializer中
def _initializer_to_array(tensor, shape):
    """Return the ONNX initializer ``tensor`` as a float32 array of ``shape``.

    ``numpy_helper.to_array`` is used instead of reading ``raw_data`` directly
    because a tensor may store its values in the typed fields (for example
    ``float_data``), in which case ``raw_data`` is empty.
    """
    return numpy_helper.to_array(tensor).astype(np.float32).reshape(shape)


# NOTE: the index -> tensor mapping follows the initializer order of this
# particular model; re-check it when converting a different ONNX file.
Conv_3_bias = _initializer_to_array(initializer[0], (1, 10))
Conv_3_weight = _initializer_to_array(initializer[1], (10, 1, 3, 3))
Conv_5_bias = _initializer_to_array(initializer[2], (1, 14))
Conv_5_weight = _initializer_to_array(initializer[3], (14, 10, 5, 5))
fc10_weight = _initializer_to_array(initializer[4], (10, 896))
fc10_bias = _initializer_to_array(initializer[5], (1, 10))
使用Parameter进行权重的赋值
根据建立好的PyTorch网络模型可以实现模型的正向推理,因此可以在Python代码上进行推理验证,完整版的代码如下所示:
# -*- coding: utf-8 -*-
# @Time : 2023-06-12 17:27
# @Author : Zander
# @Email : 1091574181@qq.com
# @File : onnx_2_pytorch.py
import os

import numpy as np
import onnx
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transformers
from onnx import numpy_helper
from torch.nn.parameter import Parameter
class Model_mnist(torch.nn.Module):
    """PyTorch rebuild of the ONNX MNIST network: two convolutions and one linear layer.

    For a 1x28x28 input the feature maps are 10x26x26 after ``conv1`` and
    14x8x8 after ``conv2``; the latter is flattened to 896 values, which is
    the input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        self.conv2 = torch.nn.Conv2d(10, 14, kernel_size=(5, 5), stride=(3, 3), padding=(0, 0))
        self.fc1 = torch.nn.Linear(896, 10, bias=True)

    def forward(self, x):
        """Run the network on ``x`` and return the ``(1, 10)`` class scores.

        Each layer's output is floor-divided by 1024.
        NOTE(review): presumably a fixed-point rescale that mirrors the
        original ONNX graph -- confirm against the source model.
        """
        out = self.conv1(x) // 1024
        out = F.relu(out)
        out = self.conv2(out) // 1024
        out = F.relu(out)
        # This is the Reshape step missing from the original ONNX graph.
        # The leading 1 fixes the batch size: only one sample per call works.
        out = out.reshape(1, -1)
        out = self.fc1(out) // 1024
        return out
def _initializer_to_array(tensor, shape):
    """Return the ONNX initializer ``tensor`` as a float32 array of ``shape``.

    ``numpy_helper.to_array`` is used instead of reading ``raw_data`` directly
    because a tensor may store its values in the typed fields (for example
    ``float_data``), in which case ``raw_data`` is empty.
    """
    return numpy_helper.to_array(tensor).astype(np.float32).reshape(shape)


# --- 1. Load the ONNX model and build the empty PyTorch counterpart ---------
model_name = "model_onnx/wit_mnist_preprocess.onnx"
model_onnx = onnx.load_model(model_name)
model_torch = Model_mnist()

# --- 2. Copy weights and biases out of the ONNX initializers ----------------
initializer = model_onnx.graph.initializer
# NOTE: the index -> tensor mapping follows the initializer order of this
# particular model; re-check it when converting a different ONNX file.
Conv_3_bias = _initializer_to_array(initializer[0], (1, 10))
Conv_3_weight = _initializer_to_array(initializer[1], (10, 1, 3, 3))
Conv_5_bias = _initializer_to_array(initializer[2], (1, 14))
Conv_5_weight = _initializer_to_array(initializer[3], (14, 10, 5, 5))
fc10_weight = _initializer_to_array(initializer[4], (10, 896))
fc10_bias = _initializer_to_array(initializer[5], (1, 10))

# Biases are stored as (1, N); squeeze() turns them into the (N,) vectors
# that Conv2d / Linear expect.
model_torch.conv1.weight = Parameter(torch.tensor(Conv_3_weight))
model_torch.conv1.bias = Parameter(torch.tensor(Conv_3_bias).squeeze())
model_torch.conv2.weight = Parameter(torch.tensor(Conv_5_weight))
model_torch.conv2.bias = Parameter(torch.tensor(Conv_5_bias).squeeze())
model_torch.fc1.weight = Parameter(torch.tensor(fc10_weight))
model_torch.fc1.bias = Parameter(torch.tensor(fc10_bias).squeeze())

# torch.save does not create missing directories, so make sure it exists.
os.makedirs('./torch_model', exist_ok=True)
torch.save(model_torch.state_dict(), './torch_model/wit_mnist.pth')

# --- 3. Check the converted model on the MNIST test split -------------------
mnist_test = torchvision.datasets.MNIST(root=r'C:\work\Code\Protobuf_test\mnist', train=False,
                                        download=True, transform=transformers.ToTensor())
print(len(mnist_test))
# Inference only: no gradients are needed.
with torch.no_grad():
    for i in range(len(mnist_test)):
        img, label = mnist_test[i]
        img = img.unsqueeze(0)  # add the batch dimension
        # ToTensor() scales pixels to [0, 1]; the model is fed the 0..255 range.
        x = img * 255
        y = model_torch(x).squeeze()
        print(torch.argmax(y), label)

# --- 4. Export the PyTorch model (now containing the reshape) back to ONNX --
img, label = mnist_test[0]
img = img.unsqueeze(0)
torch.onnx.export(model_torch,
                  img,
                  "wit_mnist_add_rashape.onnx",
                  opset_version=11,
                  input_names=["input"],    # input name
                  output_names=["output"])  # output name
代码最后调用的torch.onnx.export函数,是将补上Reshape操作之后的PyTorch模型重新导出为ONNX模型(wit_mnist_add_rashape.onnx)