onnx-modifier in Depth (a Handy Tool for Editing ONNX Models)

A web-based tool for editing ONNX models interactively

Architecture

Overview

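  1. For each Node, onnx-modifier reads the attributes dynamically and lets you modify the value of each one; there is no feature for adding new attributes.
  2. Open Model in onnx-modifier is split between front end and back end, so the model is loaded twice: once by the back end and once by the front end. The front end performs the edits (rename, add node, add output, modify attribute, delete node (the node is not actually removed; only its state is changed)), records them as JSON and sends the JSON to the back end, which applies the modifications to the model and exports it.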
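The record-and-replay split can be illustrated with a minimal sketch. It assumes a simplified plain-object model rather than an ONNX protobuf (in onnx-modifier the replay happens in the Python back end), and the 'Deleted' state value is an assumption; only the payload keys node_states and node_changed_attr come from the Download request shown later.

```javascript
// Minimal sketch of replaying recorded front-end edits on a simplified model.
// Assumptions: nodes are plain objects, and a deleted node has state 'Deleted'.
function applyRecordedEdits(model, payload) {
  const nodeStates = payload.node_states || {};
  const changedAttrs = payload.node_changed_attr || {};

  // The front end only marks nodes as deleted; they are dropped at replay time.
  const kept = model.nodes.filter((node) => nodeStates[node.name] !== 'Deleted');

  // Merge the edited attribute values over each node's attributes (copies, no mutation).
  const nodes = kept.map((node) => ({
    ...node,
    attributes: { ...node.attributes, ...(changedAttrs[node.name] || {}) },
  }));
  return { ...model, nodes };
}

const model = {
  nodes: [
    { name: 'conv_0', attributes: { group: 1 } },
    { name: 'relu_0', attributes: {} },
  ],
};
const payload = {
  node_states: { conv_0: 'Exist', relu_0: 'Deleted' },
  node_changed_attr: { conv_0: { group: 2 } },
};
console.log(JSON.stringify(applyRecordedEdits(model, payload).nodes));
// [{"name":"conv_0","attributes":{"group":2}}]
```

Keeping deletions as a state flag until replay is what makes them cheap to undo in the browser.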

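Questions

  • How can attributes be added with onnx-modifier? Possible approaches:
    • Modify the JS code
    • Add a step before using onnx-modifier that adds quantization parameters to every node
    • Have onnx-modifier check, while parsing the model, whether the quantization parameters are present, and push them if not
    • Add a button for adding attributes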
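The third approach (check at parse time, push if missing) can be sketched as follows. The attribute name quant_scale and its default value are hypothetical placeholders, and nodes are modeled as plain objects whose attributes array holds { name, value } pairs; this is not onnx-modifier code.

```javascript
// Hypothetical attribute name and default; onnx-modifier has no such feature.
const QUANT_ATTR = 'quant_scale';
const QUANT_DEFAULT = 1.0;

// Push a default quantization attribute onto every node that lacks one, so the
// existing "modify attribute" UI could then edit it. Returns how many were added.
function ensureQuantAttribute(nodes) {
  let added = 0;
  for (const node of nodes) {
    node.attributes = node.attributes || [];
    const hasAttr = node.attributes.some((attr) => attr.name === QUANT_ATTR);
    if (!hasAttr) {
      node.attributes.push({ name: QUANT_ATTR, value: QUANT_DEFAULT });
      added += 1;
    }
  }
  return added;
}

const nodes = [
  { name: 'conv_0', attributes: [{ name: 'group', value: 1 }] },
  { name: 'relu_0', attributes: [{ name: QUANT_ATTR, value: 0.5 }] },
];
console.log(ensureQuantAttribute(nodes)); // 1
```

Because the check is idempotent, running it on every model load is safe: nodes that already carry the attribute keep their values.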

Details

Open Model

Line 323 of index.html shows that the Open Model button is the open-file-button element:

<button id="consent-accept-button" class="center consent-accept-button">Accept</button>
<button id="open-file-button" class="center open-file-button">Open Model&hellip;</button>
<button id="github-button" class="center github-button">Github repo</button>

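For open-file-button, see line 294 of static/index.js: the handler opens a file dialog and checks whether the chosen file's extension is pb, onnx, pth, pdmodel, nb or h5; if it is, the model is loaded (see the fetch('/open_model') call and the this._open(file, files) call in the handler below). As things stand, the model is loaded twice: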

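  1. Back end: the /open_model route, which calls onnxModifier from onnx_modifier.py
from flask import Flask, request
from onnx_modifier import onnxModifier

app = Flask(__name__)

@app.route('/open_model', methods=['POST'])
def open_model():
    # https://blog.miguelgrinberg.com/post/handling-file-uploads-with-flask
    onnx_file = request.files['file']

    # Keep the parsed model in a module-level global so later requests (e.g. /download) can modify it
    global onnx_modifier
    onnx_modifier = onnxModifier.from_name_stream(onnx_file.filename, onnx_file.stream)

    return 'OK', 200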

  2. Front end: the model is loaded through view.js; see line 597 of static/index.js
_open(file, files) {
  this._view.show('welcome spinner');
  const context = new host.BrowserHost.BrowserFileContext(this, file, files);
  context.open().then(() => {
    return this._view.open(context).then((model) => {
      this._view.show(null);
      this.document.title = files[0].name;
      return model;
    });
  }).catch((error) => {
    this._view.error(error, null, null);
  });
}
    open(context) {
        this._host.event('Model', 'Open', 'Size', context.stream ? context.stream.length : 0);
        this._sidebar.close();
        return this._timeout(2).then(() => {
            // _modelFactoryService: registry of the model loaders for each framework
            return this._modelFactoryService.open(context).then((model) => {
                const format = [];
                if (model.format) {
                    format.push(model.format);
                }
                if (model.producer) {
                    format.push('(' + model.producer + ')');
                }
                if (format.length > 0) {
                    this._host.event('Model', 'Format', format.join(' '));
                }
                return this._timeout(20).then(() => {
                    const graphs = Array.isArray(model.graphs) && model.graphs.length > 0 ? [ model.graphs[0] ] : [];
                    return this._updateGraph(model, graphs); // _updateGraph is described below
                });
            });
        });
    }
const openFileButton = this.document.getElementById('open-file-button');
const openFileDialog = this.document.getElementById('open-file-dialog');
if (openFileButton && openFileDialog) {
  openFileButton.addEventListener('click', () => {
    openFileDialog.value = '';
    openFileDialog.click();
  });
  openFileDialog.addEventListener('change', (e) => {
    if (e.target && e.target.files && e.target.files.length > 0) {
      const files = Array.from(e.target.files);
      const file = files.find((file) => this._view.accept(file.name));
      // Bail out before touching file.name: no selected file has a supported extension
      if (!file) {
        return;
      }
      this.upload_filename = file.name;
      const form = new FormData();
      form.append('file', file);

      // Back-end load: send the file to the Flask /open_model route
      // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
      fetch('/open_model', {
        method: 'POST',
        body: form
      }).then(function (response) {
        return response.text();
      }).then(function (text) {
        console.log('POST response: ');
        // Should be 'OK' if everything was successful
        console.log(text);
      });

      // Front-end load: parse and render the model in the browser
      this._open(file, files);
    }
  });
}
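The handler above fires the upload and the local render independently. A sketch of an alternative ordering, which waits for the back-end copy before rendering, is shown below; it is a hypothetical restructuring, with upload and renderLocally as injected stand-ins for the fetch('/open_model') call and this._open, and the extension check written out from the list in the text (the tool itself delegates it to this._view.accept).

```javascript
// Extensions the text lists as accepted by Open Model.
const MODEL_EXTENSIONS = ['pb', 'onnx', 'pth', 'pdmodel', 'nb', 'h5'];

function acceptModelFile(name) {
  const dot = name.lastIndexOf('.');
  return dot >= 0 && MODEL_EXTENSIONS.includes(name.slice(dot + 1).toLowerCase());
}

// Hypothetical ordering: `upload` stands in for the POST to /open_model and
// `renderLocally` for this._open(file, files). Awaiting the upload first means
// the back-end copy exists before the user starts editing.
async function openModel(files, upload, renderLocally) {
  const file = files.find((f) => acceptModelFile(f.name));
  if (!file) {
    return null; // nothing to open
  }
  const status = await upload(file);
  if (status !== 'OK') {
    throw new Error('upload failed: ' + status);
  }
  await renderLocally(file, files);
  return file.name;
}

// Usage with stub functions in place of fetch() and the viewer.
const steps = [];
openModel(
  [{ name: 'notes.txt' }, { name: 'resnet18.onnx' }],
  async (f) => { steps.push('upload ' + f.name); return 'OK'; },
  async (f) => { steps.push('render ' + f.name); },
).then((opened) => console.log(opened, steps));
```

Injecting the two steps keeps the ordering logic testable without a running Flask server.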

Front-end interaction

The four most important buttons:
<button id="refresh-graph" class="graph-op-button-refresh">Refresh</button>
<button id="reset-graph" class="graph-op-button-reset">Reset</button>
<button id="download-graph" class="graph-op-button-download">Download</button>
<button id="add-node" class="graph-op-button-addNode">Add node</button>
  1. refresh-graph
const refreshButton = this.document.getElementById('refresh-graph');
refreshButton.addEventListener('click', () => {
  this._view._updateGraph();
})
_updateGraph(model, graphs) {
  const lastModel = this._model;
  const lastGraphs = this._graphs;
  // update graph if and only if `model` and `graphs` are provided
  if (model && graphs) {
    this._model = model;
    this._graphs = graphs;

    this.UpdateAddNodeDropDown();
  }
  this.lastViewGraph = this._graph; 
  const graph = this.activeGraph; // the key step: this getter (shown below) applies the pending edits
  // console.log(graph.nodes)

  return this._timeout(100).then(() => {
    if (graph && graph != lastGraphs[0]) {
      const nodes = graph.nodes;
      // console.log(nodes);
      if (nodes.length > 2048) {
        if (!this._host.confirm('Large model detected.', 'This graph contains a large number of nodes and might take a long time to render. Do you want to continue?')) {
          this._host.event('Graph', 'Render', 'Skip', nodes.length);
          this.show(null);
          return null;
        }
      }
    }
    const update = () => {
      const nameButton = this._getElementById('name-button');
      const backButton = this._getElementById('back-button');
      if (this._graphs.length > 1) {
        const graph = this.activeGraph;
        nameButton.innerHTML = graph ? graph.name : '';
        backButton.style.opacity = 1;
        nameButton.style.opacity = 1;
      }
      else {
        backButton.style.opacity = 0;
        nameButton.style.opacity = 0;
      }
    };
    return this.renderGraph(this._model, this.activeGraph).then(() => {
      if (this._page !== 'default') {
        this.show('default');
      }
      update();
      return this._model;
    }).catch((error) => {
      this._model = lastModel;
      this._graphs = lastGraphs;
      return this.renderGraph(this._model, this.activeGraph).then(() => {
        if (this._page !== 'default') {
          this.show('default');
        }
        update();
        throw error;
      });
    });
  });
}
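The render-with-rollback logic in _updateGraph can be sketched in isolation. This is a simplified, hypothetical GraphView class, not the tool's view: render and confirm are injected stand-ins for renderGraph and this._host.confirm, and the "different from the last graph" check is omitted.

```javascript
// Simplified stand-in for the view; `render` and `confirm` are injected
// placeholders for renderGraph and this._host.confirm.
class GraphView {
  constructor(render, confirm) {
    this.render = render;   // (model) => Promise, rejects if rendering fails
    this.confirm = confirm; // (title, message) => boolean
    this.model = null;
  }

  async update(model) {
    // Same threshold as _updateGraph: ask before rendering more than 2048 nodes.
    if (model.nodes.length > 2048 &&
        !this.confirm('Large model detected.', 'Rendering may take a long time.')) {
      return null;
    }
    const lastModel = this.model;
    this.model = model;
    try {
      await this.render(this.model);
      return this.model;
    } catch (error) {
      // Roll back to the previous model, redraw it, then rethrow.
      this.model = lastModel;
      if (lastModel) {
        await this.render(lastModel);
      }
      throw error;
    }
  }
}

// Usage: the second update fails to render, so the view falls back to model 'a'.
const drawn = [];
const graphView = new GraphView(async (m) => {
  if (m.broken) throw new Error('render failed');
  drawn.push(m.name);
}, () => false);

graphView.update({ name: 'a', nodes: [] })
  .then(() => graphView.update({ name: 'b', nodes: [], broken: true }))
  .catch((err) => console.log(err.message, graphView.model.name, drawn));
```

Rethrowing after the rollback lets the caller show the error while the canvas still displays a valid graph.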

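As the activeGraph getter below shows, it performs four refreshes, but only when there is an active graph and a previously rendered graph (lastViewGraph):

  1. Refresh the inputs and outputs of newly added nodes (refreshAddedNode)
  2. Refresh the model's inputs and outputs (refreshModelInputOutput)
  3. Refresh node inputs and outputs (refreshNodeArguments)
  4. Refresh attribute values (refreshNodeAttributes)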
    get activeGraph() {
        // return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
        var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
        if (active_graph && this.lastViewGraph) {
            this.refreshAddedNode()
            this.refreshModelInputOutput()
            this.refreshNodeArguments()
            this.refreshNodeAttributes()
        }

        return active_graph
    }
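The last of these steps can be pictured with a hypothetical helper, applyChangedAttributes (not the tool's own code). It assumes _changedAttributes is a Map from node name to a Map of attribute name to new value, which is consistent with it being serialized through mapToObjectRec on Download.

```javascript
// Hypothetical helper: overlay pending attribute edits on the nodes being drawn.
// changedAttributes: Map<nodeName, Map<attributeName, newValue>> (assumed shape).
function applyChangedAttributes(nodes, changedAttributes) {
  for (const node of nodes) {
    const edits = changedAttributes.get(node.name);
    if (!edits) continue;
    for (const attribute of node.attributes) {
      // Only attributes the node already has are updated; none are added.
      if (edits.has(attribute.name)) {
        attribute.value = edits.get(attribute.name);
      }
    }
  }
  return nodes;
}

const graphNodes = [{ name: 'conv_0', attributes: [{ name: 'group', value: 1 }] }];
const changed = new Map([['conv_0', new Map([['group', 4], ['new_attr', 7]])]]);
applyChangedAttributes(graphNodes, changed);
console.log(JSON.stringify(graphNodes[0].attributes)); // [{"name":"group","value":4}]
```

Iterating over the node's existing attributes, rather than over the edits, is what rules out adding attributes, matching the limitation noted in the overview.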

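  2. reset-graph

From the code, Reset performs the following steps:

  1. Deleted nodes are restored
  2. Added nodes are removed
  3. Added outputs are removed
  4. Renamed inputs/outputs are restored to their original names

Attribute modifications are not reverted.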

    resetGraph() {
        // reset node states
        for (const nodeId of this.nodes.keys()) {
            const node = this.node(nodeId);
            this._modelNodeName2State.set(node.label.modelNodeName, 'Exist')
        }
        
        // console.log(this._renameMap)
        // reset node inputs/outputs
        for (const changed_node_name of this._renameMap.keys()) {
            var node = this._modelNodeName2ModelNode.get(changed_node_name)
            // console.log(node)
            // console.log(typeof node)
            // console.log(node.constructor.name)
            if (node.arguments) {   // model input or model output. Because they are purely onnx.Parameter
                node.arguments[0] = this.view._graphs[0]._context.argument(node.modelNodeName)
            }
            
            else {                   // model nodes
                //reset inputs
                for (var input of node.inputs) {
                    for (var i = 0; i < input.arguments.length; ++i) {
                        // console.log(input.arguments[i].original_name)
                        if (this._renameMap.get(node.modelNodeName).get(input.arguments[i].original_name)) {
                            input.arguments[i] = this.view._graphs[0]._context.argument(input.arguments[i].original_name)
                        }
                    }
                }
                
                // reset outputs
                for (var output of node.outputs) {
                    for (var i = 0; i < output.arguments.length; ++i) {
                        if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) {
                            output.arguments[i] = this.view._graphs[0]._context.argument(output.arguments[i].original_name)
                        }
                    }
                }

            }
        }
        this._renameMap = new Map();

        // clear custom added nodes
        this._addedNode = new Map()
        this.view._graphs[0].reset_custom_added_node()
        this._addedOutputs = []
        this.view._graphs[0].reset_custom_added_outputs()
    }
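Which parts of the edit state a reset clears can be shown with a minimal sketch. It assumes a simplified class whose fields stand in for _modelNodeName2State, _renameMap, _addedNode, _addedOutputs and _changedAttributes; the 'Deleted' state value is an assumption. The attribute map is left alone, which is why attribute edits survive Reset.

```javascript
// Simplified model of the bookkeeping that resetGraph touches (field names shortened).
class EditState {
  constructor(nodeNames) {
    this.nodeStates = new Map(nodeNames.map((name) => [name, 'Exist']));
    this.renameMap = new Map();         // nodeName -> Map(originalName -> newName)
    this.addedNodes = new Map();        // modelNodeName -> node info
    this.addedOutputs = [];
    this.changedAttributes = new Map(); // nodeName -> Map(attributeName -> value)
  }

  reset() {
    for (const name of this.nodeStates.keys()) {
      this.nodeStates.set(name, 'Exist'); // undo deletions
    }
    this.renameMap = new Map();           // undo renames
    this.addedNodes = new Map();          // drop custom-added nodes
    this.addedOutputs = [];               // drop custom-added outputs
    // changedAttributes is not touched, matching resetGraph.
  }
}

const state = new EditState(['conv_0', 'relu_0']);
state.nodeStates.set('relu_0', 'Deleted');
state.addedOutputs.push('conv_0_out');
state.changedAttributes.set('conv_0', new Map([['group', 2]]));
state.reset();
console.log(state.nodeStates.get('relu_0'), state.addedOutputs.length, state.changedAttributes.size);
// Exist 0 1
```

Making Reset also revert attributes would only require clearing changedAttributes in reset(), plus re-reading the original values for display.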
  3. download-graph
        const downloadButton = this.document.getElementById('download-graph');
        downloadButton.addEventListener('click', () => {

            // console.log(this._view._graph._addedNode)
            // console.log(this._view._graph._renameMap)
            // // https://healeycodes.com/talking-between-languages
            fetch('/download', {
                // Declare what type of data we're sending
                headers: {
                  'Content-Type': 'application/json'
                },
                // Specify the method
                method: 'POST',
                body: JSON.stringify({
                    'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
                    'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
                    'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
                    'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
                    'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, 
                                                                this._view._graph._renameMap, this._view._graph._modelNodeName2State)),
                    'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo)
                })
            }).then(function (response) {
                return response.text();
            }).then(function (text) {
                console.log('POST response: ');
                // Should be 'OK' if everything was successful
                console.log(text);
                if (text == 'OK') {
                    // alert("Modified model has been successfuly saved in ./modified_onnx/");
                    swal("Success!", "Modified model has been successfuly saved in ./modified_onnx/", "success");
                }
                else {
                    // swal("Error happens!", "You are kindly to create an issue on https://github.com/ZhangGe6/onnx-modifier", "error");
                    swal("Error happens!", "You are kindly to check the log and create an issue on https://github.com/ZhangGe6/onnx-modifier", "error");
                    // alert('Error happens, you can find it out or create an issue on https://github.com/ZhangGe6/onnx-modifier')
                }
            });
        });
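The payload relies on mapToObjectRec because JSON.stringify serializes a Map as {}. Its body is not shown here; the following is a hypothetical reimplementation that recursively converts nested Maps into plain objects, while arrays and other values pass through unchanged.

```javascript
// Hypothetical reimplementation: recursively convert (nested) Maps into plain
// objects so JSON.stringify can serialize them. Non-Map values are returned as-is.
function mapToObjectRec(value) {
  if (value instanceof Map) {
    const obj = {};
    for (const [key, inner] of value) {
      obj[key] = mapToObjectRec(inner);
    }
    return obj;
  }
  return value;
}

const renameMap = new Map([['conv_0', new Map([['input', 'image']])]]);
console.log(JSON.stringify(renameMap));                 // {}
console.log(JSON.stringify(mapToObjectRec(renameMap))); // {"conv_0":{"input":"image"}}
```

Note that Map keys become object property names, so they are always strings in the resulting JSON.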
  4. add-node
const addNodeButton = this.document.getElementById('add-node');
addNodeButton.addEventListener('click', () => {
    // this._view._graph.resetGraph();
    // this._view._updateGraph();
    var addNodeDropDown = this.document.getElementById('add-node-dropdown');
    var selected_val = addNodeDropDown.options[addNodeDropDown.selectedIndex].value
    var add_op_domain = selected_val.split(':')[0]
    var add_op_type = selected_val.split(':')[1]
    // console.log(selected_val)
    this._view._graph.add_node(add_op_domain, add_op_type)
    this._view._updateGraph();
})

    add_node(op_domain, op_type) {

        var node_id = (this._add_nodeKey++).toString();  // in case input (onnx) node has no name
        var modelNodeName = 'custom_added_' + op_type + node_id

        var properties = new Map()
        properties.set('domain', op_domain)
        properties.set('op_type', op_type)
        properties.set('name', modelNodeName)
        this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties))
        // console.log(this._addedNode)

    }
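The two small conventions above, the "<domain>:<op_type>" dropdown value and the counter-based node name, can be sketched as standalone helpers. Both functions are hypothetical, the example value 'ai.onnx:Conv' is an illustrative assumption, and the counter starts at 0 only in this sketch. Splitting at the first ':' (rather than split(':')[1]) keeps any later colons in op_type.

```javascript
// Parse a dropdown value of the form "<domain>:<op_type>" (hypothetical helper).
function parseOpSelection(value) {
  const sep = value.indexOf(':');
  if (sep < 0) {
    return { domain: '', opType: value };
  }
  return { domain: value.slice(0, sep), opType: value.slice(sep + 1) };
}

// Naming scheme used by add_node: 'custom_added_' + op_type + a counter that is
// shared across all op types (starting at 0 here).
function makeNodeNamer() {
  let nextKey = 0;
  return (opType) => 'custom_added_' + opType + (nextKey++).toString();
}

const nameFor = makeNodeNamer();
const { domain, opType } = parseOpSelection('ai.onnx:Conv');
console.log(domain, nameFor(opType), nameFor('Relu'));
// ai.onnx custom_added_Conv0 custom_added_Relu1
```

Because the counter is shared, names stay unique even when the same op type is added repeatedly, which matters since the name is the key of _addedNode.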
Modifying node attribute values and renaming inputs/outputs
   Building the node sidebar (sidebar.NodeSidebar):
sidebar.NodeSidebar = class {

    constructor(host, node, modelNodeName) {
        this._host = host;
        this._node = node;
        this._modelNodeName = modelNodeName;
        this._elements = [];
        this._attributes = [];
        this._inputs = [];
        this._outputs = [];

        if (node.type) {
            let showDocumentation = null;
            const type = node.type;
            if (type && (type.description || type.inputs || type.outputs || type.attributes)) {
                showDocumentation = {};
                showDocumentation.text = type.nodes ? '\u0192': '?';
                showDocumentation.callback = () => {
                    this._raise('show-documentation', null);
                };
            }
            this._addProperty('type', new sidebar.ValueTextView(this._host, node.type.name, showDocumentation));
            if (node.type.module) {
                this._addProperty('module', new sidebar.ValueTextView(this._host, node.type.module));
            }
        }

        if (node.name) {
            this._addProperty('name', new sidebar.ValueTextView(this._host, node.name));
        }

        if (node.location) {
            this._addProperty('location', new sidebar.ValueTextView(this._host, node.location));
        }

        if (node.description) {
            this._addProperty('description', new sidebar.ValueTextView(this._host, node.description));
        }

        if (node.device) {
            this._addProperty('device', new sidebar.ValueTextView(this._host, node.device));
        }

        const attributes = node.attributes;
        if (attributes && attributes.length > 0) {
            const sortedAttributes = node.attributes.slice();
            sortedAttributes.sort((a, b) => {
                const au = a.name.toUpperCase();
                const bu = b.name.toUpperCase();
                return (au < bu) ? -1 : (au > bu) ? 1 : 0;
            });
            this._addHeader('Attributes');
            for (const attribute of sortedAttributes) {
                this._addAttribute(attribute.name, attribute);
            }
        }

        const inputs = node.inputs;
        if (inputs && inputs.length > 0) {
            this._addHeader('Inputs');
            for (const [index, input] of inputs.entries()){
            // for (const input of inputs) {
                this._addInput(input.name, input, index);  // input.name is the label shown before the small white box (not the text inside the box)
            }
        }

        const outputs = node.outputs;
        if (outputs && outputs.length > 0) {
            this._addHeader('Outputs');
            for (const [index, output] of outputs.entries()){
            // for (const output of outputs) {
                this._addOutput(output.name, output, index);
            }
        }

        this.add_separator(this._elements, 'sidebar-view-separator')
        this._elements.push(this._host.document.createElement('hr'));
        this.add_separator(this._elements, 'sidebar-view-separator')

        this._addHeader('Node deleting helper');
        this._addButton('Delete With Children');
        this.add_span()
        this._addButton('Delete Single Node');
        this.add_span()
        this._addButton('Recover Node');
        this.add_separator(this._elements, 'sidebar-view-separator')
        this._addButton('Enter');
    
        this._addHeader('Output adding helper');
        this._addButton('Add Output');
        
        // deprecated
        // this.add_separator(this._elements, 'sidebar-view-separator');
        // this._addHeader('Rename helper');
        // if (inputs && inputs.length > 0) {
        //     for (const input of inputs) {
        //         this.add_rename_aux_element(input.arguments);
        //     }
        // }
        // if (outputs && outputs.length > 0) {
        //     for (const output of outputs) {
        //         this.add_rename_aux_element(output.arguments);
        //     }
        // }

        // this.add_separator(this._elements, 'sidebar-view-separator');
        // this._addHeader('Add children node');
        // this._addDropdownSelector('AddChildrenNode');
        // this.add_span()
        // this._addButton('Add Node');

    }
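The sidebar lists attributes in case-insensitive name order, and the "add attribute" button proposed in the Questions section would plug in at this point. Below, sortAttributes reproduces the comparator above, while recordAddedAttribute is a hypothetical handler (not part of onnx-modifier) that records new attributes as a Map of Maps keyed by node name, so they could be sent to the back end alongside the Download payload.

```javascript
// Same ordering rule as NodeSidebar: case-insensitive comparison of attribute names.
function sortAttributes(attributes) {
  return attributes.slice().sort((a, b) => {
    const au = a.name.toUpperCase();
    const bu = b.name.toUpperCase();
    return au < bu ? -1 : au > bu ? 1 : 0;
  });
}

// Hypothetical handler for the proposed "add attribute" button: record new
// attributes per node (nodeName -> Map(attrName -> value)).
function recordAddedAttribute(addedAttributes, nodeName, attrName, value) {
  if (!addedAttributes.has(nodeName)) {
    addedAttributes.set(nodeName, new Map());
  }
  addedAttributes.get(nodeName).set(attrName, value);
  return addedAttributes;
}

const attrs = [{ name: 'strides' }, { name: 'Group' }, { name: 'kernel_shape' }];
console.log(sortAttributes(attrs).map((a) => a.name).join(', '));
// Group, kernel_shape, strides

const added = recordAddedAttribute(new Map(), 'conv_0', 'quant_scale', 0.5);
console.log(added.get('conv_0').get('quant_scale')); // 0.5
```

The back end would still need matching support to turn these records into real ONNX attributes before export.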