onnx截取子模型,并修改模型的输出

import torch 
import onnx
from onnx import helper, TensorProto
import numpy as np
 
class Model(torch.nn.Module): 
 
    def __init__(self): 
        super().__init__() 
        self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
        self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3), 
                                          torch.nn.Conv2d(3, 3, 3)) 
    def forward(self, x): 
        x = self.convs1(x) 
        x1 = self.convs2(x) 
        x2 = self.convs3(x) 
        x = x1 + x2 
        x = self.convs4(x) 
        return x 
 
model = Model() 
input = torch.randn(1, 3, 20, 20) 
 
torch.onnx.export(model, input, 'whole_model.onnx') 

onnx.utils.extract_model('/home/zhengwei/my_jupyter/lab8-lpr/Models/lprnet.onnx', 'partial_model.onnx', ['input.1'], ['/Div_output_0'])

model = onnx.load('partial_model.onnx')

div_node = None

for i in model.graph.node:
    if i.op_type == "Div":
        div_node = i


conv_w_weight = np.random.random_sample((1, 64, 1, 1))
conv_w = helper.make_tensor("conv_w", TensorProto.FLOAT, (1, 64, 1, 1), conv_w_weight)
conv_node = onnx.helper.make_node(
    "Conv",
    inputs=[div_node.output[0], "conv_w"],
    outputs=["conv_node_output"],
    kernel_shape=[1, 1],
    # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
)

model.graph.initializer.append(conv_w)
model.graph.node.append(conv_node)

new_output = onnx.helper.make_tensor_value_info('conv_node_output', onnx.TensorProto.FLOAT, [1, 1, 4, 18])
model.graph.output.extend([new_output])
model.graph.output.remove(model.graph.output[0])

onnx.save(model, 'tmp.onnx')









  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
截取一定范围内模型的顶点,你可以使用Three.js提供的BufferGeometry来获取模型的顶点数据。下面是一个简单的示例代码: ```javascript // 假设你已经加载了一个模型,并且它的几何体是一个BufferGeometry对象 // 获取模型的顶点属性 const positions = modelGeometry.attributes.position.array; // 创建一个新的BufferGeometry对象来存储截取后的顶点数据 const clippedGeometry = new THREE.BufferGeometry(); // 定义截取的范围 const minX = -10; const maxX = 10; const minY = -10; const maxY = 10; const minZ = -10; const maxZ = 10; // 存储截取后的顶点数据的数组 const clippedPositions = []; for (let i = 0; i < positions.length; i += 3) { const x = positions[i]; const y = positions[i + 1]; const z = positions[i + 2]; // 判断顶点是否在截取范围内 if (x >= minX && x <= maxX && y >= minY && y <= maxY && z >= minZ && z <= maxZ) { clippedPositions.push(x, y, z); } } // 将截取后的顶点数据设置给新的BufferGeometry对象 clippedGeometry.setAttribute('position', new THREE.Float32BufferAttribute(clippedPositions, 3)); ``` 在这个示例中,我们首先获取了模型的顶点属性数组。然后,我们遍历每个顶点,判断其坐标是否在截取范围内,并将符合条件的顶点数据存储到一个新的数组中。最后,我们将截取后的顶点数据设置给新创建的BufferGeometry对象的position属性。 请根据你的实际需求修改截取范围和其他参数。希望对你有帮助!如果你有其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值