PPQ ONNX quantization demo

This example shows how to quantize an ONNX model (in particular a multi-input
model) with the PPQ library. It first generates a random calibration dataset,
then quantizes the model, runs graph-wise error analysis, and compares the
quantized results against ONNXRuntime to evaluate the quantized model's
performance. Finally, it reports the SNR error between the quantized model and
the original model.
# Quantize an ONNX model (multi-input), run error analysis, and align the
# results with onnxruntime.
# The calibration dataset here must be a list of dictionaries
# (one {input_name: tensor} mapping per sample).
from typing import Iterable,Tuple
import torch
from ppq import (BaseGraph,QuantizationSettingFactory,TargetPlatform,
	convert_any_to_numpy,torch_snr_error)
from ppq.api import (dispatch_graph,export_ppq_graph,load_onnx_graph,
	quantize_onnx_model)
from ppq.core.data import convert_any_to_torch_tensor
from ppq.executor.torch import TorchExecutor
from ppq.quantization.analyse.graphwise import graphwise_error_analyse
# Shape of every graph input, keyed by ONNX input name ('input.1' is the
# default name PyTorch's exporter assigns); used to generate random
# calibration data below.
INPUT_SHAPES = {'input.1': [1,3,224,224]}
DEVICE = 'cuda'  # device used for calibration and quantized execution
QUANT_PLATFORM = TargetPlatform.TRT_INT8  # quantization target: TensorRT INT8
ONNX_PATH = 'model.onnx'  # path of the FP32 ONNX model to quantize
ONNX_OUTPUT_PATH = 'Output/model.onnx'  # where the quantized model is exported
def generate_calibration_dataset(graph: 'BaseGraph', num_of_batches: int = 32,
	input_shapes: dict = None) -> Tuple[Iterable[dict], dict]:
	"""Generate a random calibration dataset for a (possibly multi-input) graph.

	Args:
		graph: the loaded graph; only ``graph.inputs`` (iterable of input
			names) is used here.
		num_of_batches: number of random samples to generate; must be >= 1.
		input_shapes: optional {input_name: shape} mapping; defaults to the
			module-level ``INPUT_SHAPES`` table.

	Returns:
		``(dataset, sample)`` where ``dataset`` is a list of
		{input_name: random tensor} dicts and ``sample`` is the last entry,
		kept for callers that need one representative batch (e.g. tracing).

	Raises:
		ValueError: if ``num_of_batches`` < 1 (the original code would hit an
			unbound-local error on ``sample`` in that case).
		KeyError: if a graph input is missing from the shape table.
	"""
	if num_of_batches < 1:
		raise ValueError('num_of_batches must be at least 1.')
	shapes = INPUT_SHAPES if input_shapes is None else input_shapes
	dataset = []
	for _ in range(num_of_batches):
		sample = {name: torch.rand(shapes[name]) for name in graph.inputs}
		dataset.append(sample)
	return dataset, sample
def collate_fn(batch: dict, device=None) -> dict:
	"""Move every tensor of a {input_name: tensor} sample onto the target device.

	Args:
		batch: mapping of input name -> tensor.
		device: target device; defaults to the module-level ``DEVICE``
			when not given (keeps the original call sites unchanged).

	Returns:
		A new dict with each tensor moved to the device. (The original
		return annotation claimed ``torch.Tensor``; a dict is returned.)
	"""
	target = DEVICE if device is None else device
	return {k: v.to(target) for k, v in batch.items()}
# Build a QuantizationSetting object to manage the quantization process.
# NOTE(review): the original comment claimed the dispatcher is switched to
# 'conservative' and fine-tuning is requested, but the code below only
# disables LSQ optimization -- confirm the intended configuration.
QSetting = QuantizationSettingFactory.default_setting()
QSetting.lsq_optimization = False
# Load the ONNX model and let ppq dispatch the graph onto the target platform.
graph = load_onnx_graph(onnx_import_file=ONNX_PATH)
graph = dispatch_graph(graph=graph,platform=QUANT_PLATFORM)
# Every graph input must have a known shape so random calibration data
# can be generated for it.
for name in graph.inputs:
	if name not in INPUT_SHAPES:
		raise KeyError(f'Graph Input {name} needs a valid shape.')
# The SNR comparison at the bottom of this script indexes output [0],
# so a single graph output is required.
if len(graph.outputs) != 1:
	raise ValueError('This Script Requires graph to have only 1 output.')
# Generate the calibration dataset and quantize the network.
calibration_dataset,sample = generate_calibration_dataset(graph)
quantized = quantize_onnx_model(
	onnx_import_file=ONNX_PATH,calib_dataloader=calibration_dataset,
	calib_steps=32,input_shape=None,inputs=collate_fn(sample),
	setting=QSetting,collate_fn=collate_fn,platform=QUANT_PLATFORM,
	device=DEVICE,verbose=0)
# After quantization, execute the quantized graph with ppq's TorchExecutor
# and keep its outputs; they will be compared against onnxruntime below.
executor,reference_outputs = TorchExecutor(quantized), []
for sample in calibration_dataset:
	reference_outputs.append(executor.forward(collate_fn(sample)))
# Layer-wise quantization error report over the calibration set.
graphwise_error_analyse(
	graph=quantized,running_device=DEVICE,
	collate_fn=collate_fn,dataloader=calibration_dataset)
# Export the quantized graph in ONNXRUNTIME format for cross-checking.
export_ppq_graph(graph=quantized,platform=TargetPlatform.ONNXRUNTIME,
	graph_save_to=ONNX_OUTPUT_PATH)
# onnxruntime is an optional dependency needed only for this cross-check.
try:
	import onnxruntime
except ImportError as e:
	# Raise a specific exception type (ImportError subclasses Exception, so
	# existing handlers still match) and chain the original error so the
	# root cause remains visible in the traceback.
	raise ImportError('Onnxruntime is not installed.') from e
# Run the exported quantized model with onnxruntime over the same samples.
# NOTE(review): onnxruntime may fall back to other providers if the CUDA EP
# is unavailable -- confirm this is acceptable for the comparison.
sess = onnxruntime.InferenceSession(ONNX_OUTPUT_PATH,providers=['CUDAExecutionProvider'])
output_name = sess.get_outputs()[0].name
onnxruntime_outputs = []
for sample in calibration_dataset:
	# Samples were created on CPU by torch.rand, so they convert directly
	# to numpy for onnxruntime's input feed.
	onnxruntime_outputs.append(sess.run(
		output_names=[output_name],
		input_feed={k:convert_any_to_numpy(v) for k,v in sample.items()}))
# Compare ppq executor outputs against onnxruntime outputs via SNR.
y_pred,y_real = [], []
for reference_output,onnxruntime_output in zip(reference_outputs,onnxruntime_outputs):
	# [0] selects the single graph output (enforced earlier in this script);
	# unsqueeze(0) adds a leading axis so per-sample results can be stacked
	# along dim 0 before the SNR computation.
	y_pred.append(convert_any_to_torch_tensor(reference_output[0],device='cpu').unsqueeze(0))
	y_real.append(convert_any_to_torch_tensor(onnxruntime_output[0],device='cpu').unsqueeze(0))
y_pred = torch.cat(y_pred,dim=0)
y_real = torch.cat(y_real,dim=0)
print(f'Simulating Error For {output_name}: {torch_snr_error(y_pred=y_pred,y_real=y_real).item() * 100 :.4f}%')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值