【模型转化】修改onnx节点属性

    在之前的一篇文章【模型转换】onnx转tensorrt报错:Attribute not found: axes中提到squeeze操作在不明确指定axes参数时onnx转tensorrt会报错。解决办法也很简单,因为我是先将pytorch转onnx,再转的tensorrt。pytorch网络结构中,给squeeze操作指定好axes参数再重新生成onnx即可。实际上我们还可以借助onnx的helper功能直接修改onnx节点属性来解决这个问题。

    例如,我要将卷积层的输出[1x64x40000x1]通过Squeeze操作进行维度压缩,剔除维数为1的维度,输出tensor[64x40000]给Transpose层。

import onnx                                                                        
model = onnx.load('my.onnx')
for node_id,node in enumerate(model.graph.node):
    print("######%s######" % node_id)                                                                                                                                                                                                                                 
    print(node)                                                                   

通过简单操作,我们可以很快定位到onnx中Squeeze的基本信息。

        
######44######
input: "383"
input: "vfe.pfn_layers.0.conv3.weight"
output: "384"
name: "Conv_249"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 3
  type: INTS
}           
attribute {
  name: "group"
  i: 1  
  type: INT
}           
attribute {
  name: "kernel_shape"
  ints: 1
  ints: 11
  type: INTS
}           
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}           
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}           
            
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
            
######46######
input: "385"
output: "386"
name: "Transpose_251"
op_type: "Transpose"
attribute {
  name: "perm"
  ints: 1
  ints: 0
  type: INTS
}           
            

 可见,相比前后的Conv层和Transpose层,Squeeze层压根就没有属性信息。这里通过onnx的helper功能给该节点加上属性,指定要降维的轴。我这里因为只是要修改第1个Squeeze,所以可以如下修改:

model = onnx.load('my.onnx')
for node_id,node in enumerate(model.graph.node):                                                                                                                                                                                                                      
    if node.op_type == "Squeeze":
        attr = onnx.helper.make_attribute("axes",[0,3])
        node.attribute.insert(0,attr)
        break

再看:

            
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
attribute {
  name: "axes"
  ints: 0
  ints: 3
  type: INTS
}           
            

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值