ONNX系列: ONNX模型修改

ONNX 模型修改

        当我们熟悉了ONNX模型各个层级的结构后,我们便可以针对各个结构来对模型进行修改,从而使其更好的适配后端运行时或者特定硬件平台的编译器。对模型的修改通常可以概括为"增删改查"的操作。"增"是增加相应结构,"删"是删除相应结构,"改"是修改相应结构,"查"是获取到指定的模型结构。修改ONNX模型通常有两种思路,一是使用ONNX官方提供的Python API;二是使用第三方ONNX模型修改工具,例如onnx-graphsurgeon工具。本文将聚焦第一种方案,介绍如何使用ONNX官方API来对ONNX模型进行"增删改查"修改。完整的ONNX官方API文档可以参考:https://onnx.ai/onnx/index.html。

1. ONNX 模型的"查"

我们想要修改ONNX模型,首先需要知道如何定位到自己感兴趣的位置,比如如何找到具体某个节点、某个 initializer,、计算图的input/output, 某个节点的 input/output以及某个 value_info。参考下面的代码,我们可以发现定位某个元素的基本思路就是遍历该元素的列表,然后根据该元素在计算图中独有的属性名称来实现定位。下面的代码实现了定位下图模型的各个元素。

# 根据算子的名字来找到目标节点

for item in model.graph.node:

    if item.name == 'Conv_1':

        print(item)

# 有的onnx模型中算子没有name属性,可以根据算子类型和输出的名字来组合找到目标节点

for item in model.graph.node:

    if item.op_type == 'Conv':

        if '1338' in item.output:

            print(item)

# 找到目标 intializer

for i in model.graph.initializer:

    if i.name == '1339':

        print(i.dims)

        print(i.dims)

        print(i.data_type)

        # 二进制形式打印,可能比较长

        print(i.raw_data)

# 找到 graph 的input和output

for i in model.graph.input:

    if i.name == 'input':

        print(i.name)

        print(i.type)

# 找到 graph 的valueinfo

for i in model.graph.value_info:

    if i.name == '9':

        print(i.name)

        print(i.type)

2. ONNX 模型的"删"

在了解了如何定位到需要修改的部分后,我们就可以对ONNX模型进行魔改了。我们首先了解如何删除ONNX模型中的指定节点或元素。下面的代码实现了删除图中标注的节点。

import onnx

# 加载模型

model = onnx.load('./super-resolution-10.onnx')

# 根据输入获取指定节点

def get_node_with_input(model, input_name):

    res = []

    for i in model.graph.node:

        if input_name in i.input:

            res.append(i)

    return res

# 根据输出获取指定节点

def get_node_with_output(model, output_name):

    res = []

    for i in model.graph.node:

        if output_name in i.output:

            res.append(i)

    return res

# 删除指定节点并将前后节点连接起来

remove_nodes = []

p = None

n = None

for i in model.graph.node:

    if '10' in i.input:

        # p = find_node_with_output(i.input[0])

        p = get_node_with_output(model, i.input[0])[0]

        remove_nodes.append(i)

    if '11' in i.input:

        # n = find_node_with_input(i.output[0])

        n = get_node_with_input(model, i.output[0])[0]

        remove_nodes.append(i)

n.input[0] = p.output[0]

for i in remove_nodes:

    model.graph.node.remove(i)

onnx.checker.check_model(model)

onnx.save(model, 'super-resolution-10-delete.onnx')

3. ONNX 模型的"增"

"增"是指在ONNX模型指定位置添加节点。在了解添加节点之前,我们首先需要了解如何创建 ONNX 节点。下面以创建一个2D卷积算子和一个ReLu算子为例,并尝试将上一步骤中删除的这两个节点重新添加回模型当中(注意我们权重没有与原模型保持一致)。

node1 = onnx.helper.make_node(

        name="Conv_0",   # 节点名字,不要和op_type搞混了

        op_type="Conv",  # 节点的算子类型, 比如'Conv'、'Relu'、'Add'这类,详细可以参考onnx给出的算子列表

        inputs=["image", "conv.weight", "conv.bias"],  # 各个输入的名字,结点的输入包含:输入和算子的权重。必有输入X和权重W,偏置B可以作为可选。

        outputs=["11"], 

        pads=[1, 1, 1, 1], # 其他字符串为节点的属性,attributes在官网被明确的给出了,标注了default的属性具备默认值。

        group=1,

        dilations=[1, 1],

        kernel_shape=[3, 3],

        strides=[1, 1]

    )

initializer_w = onnx.helper.make_tensor(

        name="conv.weight",

        data_type=onnx.helper.TensorProto.DataType.FLOAT,

        dims=[64, 64, 3, 3],

        vals=np.ones([64,64,3,3], dtype=np.float32).tobytes(),

        raw=True

    )

initializer_b = onnx.helper.make_tensor(

        name="conv.bias",

        data_type=onnx.helper.TensorProto.DataType.FLOAT,

        dims=[64],

        vals=np.ones([64], dtype=np.float32).tobytes(),

        raw=True

    )

node2 = onnx.helper.make_node(

        name="ReLU_1",

        op_type="Relu",

        inputs=["11"],

        outputs=["12"]

    )

下面代码将上述创建的两个节点插入到模型指定位置。

for i in range(len(model.graph.node)):

    if '10' in model.graph.node[i].output:

        model.graph.node[i].output[0] = 'pre_output'

        model.graph.node[i+1].input[0] = 'relu_output'

        model.graph.node.insert(i+1, node1)

        model.graph.node.insert(i+2, node2)

model.graph.initializer.append(initializer_w)

model.graph.initializer.append(initializer_b)

input = model.graph.input[0]

new_input = onnx.helper.make_tensor_value_info(input.name, onnx.TensorProto.FLOAT, [1,1,224,224])

model.graph.input[0].CopyFrom(new_input)

onnx.checker.check_model(model)

model = onnx.shape_inference.infer_shapes(model)

onnx.save(model, 'super-resolution-10-insert.onnx')

4. ONNX 模型的"改"

通常来说修改 ONNX 模型可以概括为一下两种:

  • 修改模型节点
  • 修改权重(initializer)

修改模型的节点可以通过上述的删除 + 添加节点组合操作来实现,这里不再赘述。下面将介绍如何修改节点权重。节点权重通常保存在initializer中,下面代码尝试将Conv算子中的bias缩小10倍。

import onnx

model = onnx.load("./super-resolution-10.onnx")

# 得到所有 initializer

all_initializer = model.graph.initializer

# 定位到目标 initializer

target_initializer = 'conv1.bias'

idx = ''

scale_factor = 10

for i, j in enumerate(all_initializer):

    if j.name == target_initializer:

        idx = i

        break

# 将 conv1 算子的 bias 缩小10倍

model.graph.initializer[idx].raw_data = (onnx.numpy_helper.to_array(all_initializer[idx]) / scale_factor).tobytes()

onnx.save(model,'super-resolution-10-scale.onnx')

总结

当我们在实际部署模型时,会根据具体硬件特性来在 ONNX 模型层面做相应的优化修改,使其能在特定的硬件平台上获得更好的推理性能。本文简单介绍了如何调用 ONNX 官方API来对 ONNX 模型进行增删改查,更加复杂的模型修改操作通常是上述四种操作的各种组合。

使用ONNX 官方API需要我们对 ONNX 模型的定义和Proto结构足够熟悉,并且通过本文中的示例代码可以看到,繁多复杂的API在使用过程中也不是很方便。在实际工作中,我们一般使用NV提供的onnx-graphsurgeon工具来快速对ONNX模型进行修改验证。这个工具在官方ONNX API的基础上提供了更为友好的高级API封装,大大提升了我们修改ONNX模型的效率,在之后的文章中我们将进一步详细介绍这个工具的使用。

作者:高通工程师,阮慧源(Huiyuan Ruan)

  • 13
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值