onnx截取子模型,并修改模型的输出

import torch 
import onnx
from onnx import helper, TensorProto
import numpy as np
 
class Model(torch.nn.Module): 
 
    def __init__(self): 
        super().__init__() 
        self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
    def forward(self, x): 
        x = self.convs1(x) 
        x1 = self.convs2(x) 
        x2 = self.convs3(x) 
        x = x1 + x2 
        x = self.convs4(x) 
        return x 
 
model = Model() 
input = torch.randn(1, 3, 20, 20) 
 
torch.onnx.export(model, input, 'whole_model.onnx') 

onnx.utils.extract_model('/home/zhengwei/my_jupyter/lab8-lpr/Models/lprnet.onnx', 'partial_model.onnx', ['input.1'], ['/Div_output_0'])

model = onnx.load('partial_model.onnx')

div_node = None

for i in model.graph.node:
    if i.op_type == "Div":
        div_node = i


conv_w_weight = np.random.random_sample((1, 64, 1, 1))
conv_w = helper.make_tensor("conv_w", TensorProto.FLOAT, (1, 64, 1, 1), conv_w_weight)
conv_node = onnx.helper.make_node(
    "Conv",
    inputs=[div_node.output[0], "conv_w"],
    outputs=["conv_node_output"],
    kernel_shape=[1, 1],
    # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
)

model.graph.initializer.append(conv_w)
model.graph.node.append(conv_node)

new_output = onnx.helper.make_tensor_value_info('conv_node_output', onnx.TensorProto.FLOAT, [1, 1, 4, 18])
model.graph.output.extend([new_output])
model.graph.output.remove(model.graph.output[0])

onnx.save(model, 'tmp.onnx')









  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值