onnx模型修改:将均值和方差放到模型中

训练模型时,一般都会对原始数据进行归一化再送入网络,即减均值和除方差。在部署时,我们也要进行同样的操作。有些推理框架会提供对应的接口,我们只需要设置均值和方差即可,如MNN.也有一些框架不提供这样的功能,如Tensorrt,这时,我们就需要自己去逐像素进行这个操作,不仅繁琐,还可能比较耗时。还有一种方式是将这个操作放到模型中,一个方法是在我们的原始pytorch模型中增加一个固定参数的Batchnorm层,另一种方式就是本文要讲的在导出的onnx模型中插入Sub和Div节点来完成。

插入节点

主要步骤:
1.创建2个常量节点,分别是均值和方差向量
2.分别插入一个Sub和Div节点,Sub节点输入是模型的输入和均值节点,Div节点输入是Sub节点输出和方差向量
3.将输入层后的第一层的输入修改为Div节点的输出。
代码如下:

import onnx
from onnx import numpy_helper
import numpy as np

# 加载ONNX模型
model_path = "xx.onnx"
model = onnx.load(model_path)
# 创建均值和方差张量
mean_value = [0.485*255, 0.456*255, 0.406*255]
variance_value = [0.229*255, 0.224*255, 0.225*255]

mean_tensor = numpy_helper.from_array(np.array(mean_value, dtype=np.float32).reshape(1,3,1,1), "mean")
variance_tensor = numpy_helper.from_array(np.array(variance_value, dtype=np.float32).reshape(1,3,1,1), "variance")

# 插入均值和方差节点
mean_node = onnx.helper.make_node("Constant", [], ["mean"], value=mean_tensor)
variance_node = onnx.helper.make_node("Constant", [], ["variance"], value=variance_tensor)

model.graph.node.insert(0, mean_node)
model.graph.node.insert(1, variance_node)

# 插入归一化节点
input_name = model.graph.input[0].name
normalize_node = onnx.helper.make_node("Sub", [input_name, "mean"], ["sub_output"])
scale_node = onnx.helper.make_node("Div", ["sub_output", "variance"], ["input_norm"])

# 插入节点到模型中
model.graph.node.insert(2, normalize_node)
model.graph.node.insert(3, scale_node)

# 更新模型
model.graph.node[4].input[0] = "input_norm"

# 保存修改后的ONNX模型
modified_model_path = "xx_norm.onnx"
# shape inference
model = onnx.shape_inference.infer_shapes(model)
# check model
onnx.checker.check_model(model)
onnx.save(model, modified_model_path)

在这里插入图片描述

验证模型结果

比对修改前和修改后的模型输出

import onnxruntime as ort
import numpy as np

# 加载原始模型
origin_ort = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
origin_input_name = origin_ort.get_inputs()[0].name
origin_output_name = origin_ort.get_outputs()[0].name
input = np.random.randn(1, 3, 128, 128).astype(np.float32)
input_norm = (input - np.array(mean_value,np.float32).reshape(1, 3, 1, 1)) / np.array(variance_value,np.float32).reshape(1, 3, 1, 1)
origin_output = origin_ort.run([origin_output_name], {origin_input_name: input_norm})[0]

# 加载修改后的模型
modified_ort = ort.InferenceSession(modified_model_path, providers=["CPUExecutionProvider"])
modified_input_name = modified_ort.get_inputs()[0].name
modified_output_name = modified_ort.get_outputs()[0].name
modified_output = modified_ort.run([modified_output_name], {modified_input_name: input})[0]

# 比较两个模型的输出
print(np.allclose(origin_output, modified_output, atol=1e-3))

输出为True证明添加节点后的模型正确,后续使用使,不再需要在外部进行均值和方差操作。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
【课程简介】 本课程适合所有对金融知识和MATLAB感兴趣的同学,通过本课程,你不仅可以学习到如何应用MATLAB,还可以学习到如何使用MATLAB进行金融数据处理与金融数据分析 【完整课程列表】 基于MATLAB的金融数据分析 金融MATLAB-第01,02章 金融市场与金融产品 MATLAB基础知识(共47页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第03章 MATLAB与Excel文件的数据交换(共41页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第05章 贷款按揭与保险产品 现金流分析案例(共44页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第06章 随机模拟 概率分布与随机数(共33页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第07章 cftool数据拟合 GDP与用电量增速分析(共22页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第08章 策略模拟 组合保险策略分析(共32页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第09章 KMV模型求解 方程与方程组的数值解(共31页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第10章 期权定价模型与数值方法(共23页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第12章 马克维兹均值模型(共19页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第13章 投资组合绩效(共22页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第17章 固定收益证券的久期与凸度(共12页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第18章 利率的期限结构(共7页).ppt 基于MATLAB的金融数据分析 金融MATLAB-第22章 技术分析 指标计算与绘图(共8页).ppt

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CodingInCV

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值