Converting Between PyTorch and ONNX Models

The problem at hand: there is an ONNX network model, which can be visualized with Netron.

The visualized structure of the model is as follows:

The model has three layers in total: two convolution layers and one fully connected layer. The kernel size and the input and output shapes of each convolution can be read off the graph. One problem stands out, though: before the fully connected layer the graph is missing a Reshape operator that flattens the 1×14×8×8 feature map into a one-dimensional vector, which is then fed into the fully connected layer to produce the [1, 10] output.
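
If Netron is not at hand, the same structure can be checked directly in Python with the onnx package. A minimal sketch, assuming the model file name used later in this post:

import onnx

model_onnx = onnx.load_model("model_onnx/wit_mnist_preprocess.onnx")
# Each node lists its operator type and its input/output tensor names;
# scanning them shows there is no Reshape between the last Conv and the
# fully connected node.
for node in model_onnx.graph.node:
    print(node.op_type, list(node.input), "->", list(node.output))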

The task now is to convert this ONNX network model into a PyTorch network model, which takes two steps:

1. Rebuild the network model (structure) in the PyTorch framework according to the visualized graph

import torch
import torch.nn.functional as F


class Model_mnist(torch.nn.Module):
    def __init__(self):
        super(Model_mnist, self).__init__()
        # Layer hyper-parameters are read off the Netron visualization
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        self.conv2 = torch.nn.Conv2d(10, 14, kernel_size=(5, 5), stride=(3, 3), padding=(0, 0))
        self.fc1 = torch.nn.Linear(896, 10, bias=True)

    def forward(self, x):
        # Each layer output is floor-divided by 1024, kept from the original
        # model (presumably a fixed-point scaling factor)
        out = self.conv1(x) // 1024
        out = F.relu(out)
        out = self.conv2(out) // 1024
        out = F.relu(out)
        # The Reshape missing from the ONNX graph: flatten 1x14x8x8 -> 1x896
        out = out.reshape(1, -1)
        out = self.fc1(out) // 1024

        return out
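
Before loading the real weights, it is worth checking that this structure reproduces the shapes shown in Netron: a 1×1×28×28 MNIST input should give a 1×14×8×8 feature map before the flatten (14×8×8 = 896) and a [1, 10] output. A minimal shape check with random, untrained weights:

import torch

model = Model_mnist()
x = torch.randn(1, 1, 28, 28)                   # dummy MNIST-sized input
feat = model.conv2(torch.relu(model.conv1(x)))  # feature map before the flatten
print(feat.shape)                               # expected: torch.Size([1, 14, 8, 8])
print(model(x).shape)                           # expected: torch.Size([1, 10])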

2. Read the ONNX model's weights and biases; both are stored in model_onnx.graph.initializer

import onnx
import numpy as np

model_onnx = onnx.load_model("model_onnx/wit_mnist_preprocess.onnx")
initializer = model_onnx.graph.initializer
# Indices and target shapes are specific to this model (read off the Netron graph)
Conv_3_bias = np.frombuffer(initializer[0].raw_data, dtype=np.float32).reshape(1, 10)
Conv_3_weight = np.frombuffer(initializer[1].raw_data, dtype=np.float32).reshape(10, 1, 3, 3)
Conv_5_bias = np.frombuffer(initializer[2].raw_data, dtype=np.float32).reshape(1, 14)
Conv_5_weight = np.frombuffer(initializer[3].raw_data, dtype=np.float32).reshape(14, 10, 5, 5)
fc10_weight = np.frombuffer(initializer[4].raw_data, dtype=np.float32).reshape(10, 896)
fc10_bias = np.frombuffer(initializer[5].raw_data, dtype=np.float32).reshape(1, 10)
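
Hard-coding indices and shapes works for this small model; a more generic read (sketched here, not part of the original script) can use onnx.numpy_helper, which restores each initializer's name, dtype, and shape directly:

import onnx
from onnx import numpy_helper

model_onnx = onnx.load_model("model_onnx/wit_mnist_preprocess.onnx")
# Map each initializer name to a NumPy array with the shape stored in the file
weights = {init.name: numpy_helper.to_array(init) for init in model_onnx.graph.initializer}
for name, arr in weights.items():
    print(name, arr.shape, arr.dtype)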

The weights are then assigned to the PyTorch layers via Parameter.
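
The Parameter assignment itself appears in the full script below. As an alternative sketch (not the original approach), the same arrays can be packed into a state_dict and loaded, which lets load_state_dict verify every shape; the arrays from np.frombuffer are read-only, hence the copy():

import torch

model_torch = Model_mnist()
state_dict = {
    "conv1.weight": torch.from_numpy(Conv_3_weight.copy()),
    "conv1.bias": torch.from_numpy(Conv_3_bias.copy()).squeeze(),
    "conv2.weight": torch.from_numpy(Conv_5_weight.copy()),
    "conv2.bias": torch.from_numpy(Conv_5_bias.copy()).squeeze(),
    "fc1.weight": torch.from_numpy(fc10_weight.copy()),
    "fc1.bias": torch.from_numpy(fc10_bias.copy()).squeeze(),
}
model_torch.load_state_dict(state_dict)  # raises if any shape does not match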

With the rebuilt PyTorch network model, forward inference can be run end to end, so the conversion can be verified directly in Python. The complete code is shown below:

# -*- coding: utf-8 -*-
# @Time : 2023-06-12 17:27
# @Author : Zander
# @Email : 1091574181@qq.com
# @File : onnx_2_pytorch.py

import onnx
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torchvision
import torchvision.transforms as transformers


class Model_mnist(torch.nn.Module):
    def __init__(self):
        super(Model_mnist, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0))
        self.conv2 = torch.nn.Conv2d(10, 14, kernel_size=(5, 5), stride=(3, 3), padding=(0, 0))
        self.fc1 = torch.nn.Linear(896, 10, bias=True)

    def forward(self, x):
        out = self.conv1(x) // 1024
        out = F.relu(out)
        out = self.conv2(out) // 1024
        out = F.relu(out)
        out = out.reshape(1, -1)
        out = self.fc1(out) // 1024

        return out


model_name = "model_onnx/wit_mnist_preprocess.onnx"
model_onnx = onnx.load_model(model_name)
model_torch = Model_mnist()
initializer = model_onnx.graph.initializer
# Optional sanity pass: every initializer's raw bytes should decode as float32
for i in range(len(initializer)):
    raw_data = np.frombuffer(initializer[i].raw_data, dtype=np.float32)

# Raw initializer bytes -> NumPy arrays, reshaped to the layer shapes above
Conv_3_bias = np.frombuffer(initializer[0].raw_data, dtype=np.float32).reshape(1, 10)
Conv_3_weight = np.frombuffer(initializer[1].raw_data, dtype=np.float32).reshape(10, 1, 3, 3)
Conv_5_bias = np.frombuffer(initializer[2].raw_data, dtype=np.float32).reshape(1, 14)
Conv_5_weight = np.frombuffer(initializer[3].raw_data, dtype=np.float32).reshape(14, 10, 5, 5)
fc10_weight = np.frombuffer(initializer[4].raw_data, dtype=np.float32).reshape(10, 896)
fc10_bias = np.frombuffer(initializer[5].raw_data, dtype=np.float32).reshape(1, 10)
# Wrap each array as a Parameter; squeeze() turns the (1, N) biases into the 1-D shape PyTorch expects
model_torch.conv1.weight = Parameter(torch.tensor(Conv_3_weight))
model_torch.conv1.bias = Parameter(torch.tensor(Conv_3_bias).squeeze())
model_torch.conv2.weight = Parameter(torch.tensor(Conv_5_weight))
model_torch.conv2.bias = Parameter(torch.tensor(Conv_5_bias).squeeze())
model_torch.fc1.weight = Parameter(torch.tensor(fc10_weight))
model_torch.fc1.bias = Parameter(torch.tensor(fc10_bias).squeeze())
torch.save(model_torch.state_dict(), './torch_model/wit_mnist.pth')  # the torch_model directory must already exist
mnist_test = torchvision.datasets.MNIST(root=r'C:\work\Code\Protobuf_test\mnist', train=False,
                                        download=True, transform=transformers.ToTensor())
print(len(mnist_test))
# Run the whole test set through the rebuilt model and compare predictions with labels
for i in range(len(mnist_test)):
    img, label = mnist_test[i]
    img = img.unsqueeze(0)  # add the batch dimension: 1x1x28x28
    x = img * 255           # ToTensor gives 0-1 values; this model expects raw 0-255 pixels
    y = model_torch(x).squeeze()
    print(torch.argmax(y), label)
# Export the rebuilt PyTorch model (now including the Reshape) back to ONNX,
# using the first test image as the example input for tracing
img, label = mnist_test[0]
img = img.unsqueeze(0)
torch.onnx.export(model_torch,
                  img,
                  "wit_mnist_add_rashape.onnx",
                  opset_version=11,
                  input_names=["input"],    # input name
                  output_names=["output"])  # output name

The final torch.onnx.export call converts the PyTorch model back into an ONNX model.
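
To confirm that the exported file behaves like the PyTorch model, the two can be run side by side with onnxruntime. A minimal sketch, continuing from the script above (so img and model_torch are still in scope; onnxruntime is assumed to be installed):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("wit_mnist_add_rashape.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(["output"], {"input": img.numpy()})[0]
torch_out = model_torch(img).detach().numpy()
print(np.max(np.abs(onnx_out - torch_out)))  # should be close to 0 if the export is faithful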
