PPQ 中analyse使用demo

# 误差分析
from typing import Iterable
import torch
import torchvision
from ppq import TargetPlatform,graphwise_error_analyse
from ppq.api import quantize_torch_model
# 调用ppq的cuda内核函数以提高速度。它们比Torch内核快5-100倍,而且GPU内存成本更低。
from ppq.api.interface import ENABLE_CUDA_KERNEL
# 统计工作,强大的分析功能对网络进行深入研究 返回值是一个统计参数的集合,可以用pandas进行处理
from ppq.quantization.analyse.graphwise import statistical_analyse
from ppq.quantization.analyse.layerwise import layerwise_error_analyse

BATCHSIZE =32
INPUT_SHAPE = [BATCHSIZE,3,224,224]
DEVICE = 'cuda'
PLATFORM = TargetPlatform.TRT_INT8
# PPQ需要送入32-1024个样本数据作为校准数据集
# 它们应尽可能服从真实样本的分布,量化过程如同训练过程一样可能存在过拟合问题
# 应当保证校准数据是经过正确预处理的、有代表性的数据,否则量化将会失败;校准数据不需要标签,数据集不能乱序
def load_calibration_dataset() -> Iterable:
	return [torch.rand(size=INPUT_SHAPE) for _ in range (32)]
CALIBRATION = load_calibration_dataset()
def collate_fn(batch: torch.Tensor) -> torch.Tensor:
	return batch.to(DEVICE)
# 使用mobilenetv2作为样例模型
# PPQ使用torch.onnx.export函数把pytorch模型转换为onnx
# 对于复杂的pytorch模型,你或许需要自己完成pytorch模型到onnx的转换
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)
# 如果使用ENABLE_CUDA_KERNEL方法
# PPQ将尝试编译高性能量化算子,这一过程需要编译环境的支持
# 如果在编译过程中发生错误,你可以删除此处对于ENABLE_CUDA_KERNEL方法的调用
# 这将显著降低PPQ的运算速度,但即使无法编译这些算子,你仍然可以使用pytorch的gpu算子完成量化
with ENABLE_CUDA_KERNEL():
	quantized = quantize_torch_model(
		model=model,calib_dataloader=CALIBRATION,
		calib_steps=32,input_shape=INPUT_SHAPE,
		collate_fn=collate_fn,platform=PLATFORM,
		onnx_export_file='Output/onnx.model',device=DEVICE,verbose=0)
	# graphwise_error_analyse是最常用的分析方法,它分析网络中的量化误差情况,将结果直接打印在屏幕上
	# 对于graphwise_error_analyse而言,算子的误差直接衡量了量化网络与浮点网络之间的输出误差
	# 这一误差是累积的,意味着网络后面的算子总是会比网络前面的算子拥有更高的输出误差
	# 留意网络输出的误差情况,如果你想获得一个精度较高的量化网络,那么靠近输出节点的误差不应超过10%
	# 该方法只衡量Conv,Gemm算子的误差情况,如果对其余算子的误差感兴趣,需手动修改方法逻辑
	reports = graphwise_error_analyse(
		graph=quantized,running_device=DEVICE,collate_fn=collate_fn,
		dataloader=CALIBRATION)
	# layerwise_error_analyse是更为强大的分析方法,它分析算子的量化敏感性
	# 与graphwise_error_analyse不同,该方法分析的误差不是累计的
	# 该方法首先解除网络中所有算子的量化,而后单独地量化每一个Conv,Gemm算子
	# 以此来衡量量化单独一个算子对网络输出的影响情况,该方法常被用来决定网络调度与混合精度量化
	# 可以将那些误差较大的层送往TargetPlatform.FP32
	reports = layerwise_error_analyse(
		graph=quantized,running_device=DEVICE,collate_fn=collate_fn,
		dataloader=CALIBRATION)
	# statistical_analyse是强有力的统计分析方法,该方法统计每一层的输入、输出以及参数的统计分布情况
	# 使用这一方法,将更清晰了解网络的量化情况,并能有针对性的选择优化方案
	# 推荐在网络量化不佳时,使用statistical_analyse辅助分析
	# 该方法不打印任何数据,需要手动将数据保存到硬盘并进行分析
	report = statistical_analyse(
		graph=quantized,running_device=DEVICE,
		collate_fn=collate_fn,dataloader=CALIBRATION)
	from pandas import DataFrame
	report = DataFrame(report)
	report.to_csv('1.csv')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值