PPQ中calibration使用demo

# 这个例子展示PPQ的校准方法,这些校准方法将计算出网络的量化参数
from typing import Iterable
import torch
import torchvision
from ppq import TargetPlatform,graphwise_error_analyse
from ppq.api.interface import (ENABLE_CUDA_KERNEL,dump_torch_to_onnx,
			load_onnx_graph,quantize_native_model)
from ppq.api.setting import QuantizationSettingFactory

BATCHSIZE = 32
INPUT_SHAPE = [BATCHSIZE,3,224,224]
DEVICE = 'cuda'
PLATFORM = TargetPlatform.TRT_INT8
def load_calibration_dataset() -> Iterable:
	return [torch.rand(INPUT_SHAPE) for _ in range (32)]
CALIBRATION = load_calibration_dataset()
def collate_fn(batch: torch.Tensor) -> torch.Tensor:
	return batch.to(DEVICE)
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)
# PPQ提供k1,mse,minmax,isotone,percentile(默认)五种校准方法
# 每一种校准方法还有更多参数可供调整,PPQ也允许你单独调整某一层的量化校准方法
# 首先展示以QSetting的方法调整量化校准参数(推荐)
QSetting = QuantizationSettingFactory.default_setting()
# 激活值校准算法,不区分大小写,可以选择minmax,k1,percentile,MSE,None
# 选择None时,将由quantizer指定量化算法。
QSetting.quantize_activation_setting.calib_algorithm = 'k1'
# 参数校准算法,不区分大小写,可以选择minmax,k1,percentile,MSE,None
# 选择None时,将由quantizer指定量化算法。
# 同时还设置了选项:是否处理被动量化参数,是否执行参数烘焙
QSetting.quantize_parameter_setting.calib_algorithm = 'minmax'
# 当你选择某种校准方法,可以进入ppq.core.common
# OBSERVER_KL_HIST_BINS,OBSERVER_PERCENTILE,OBSERVER_HIST_BINS皆是与校准方法有关的可调整的参数
# OBSERVER_KL_HIST_BINS - KL 算法相关的箱子个数,可以将其调整为512,1024,2048,4096,8192...
# OBSERVER_PERCENTILE - Percentile 算法相关百分比,可以将其调整为0.9999,0.9995,0.99999,0.99995...
# OBSERVER_KL_HIST_BINS - MSE 算法相关的箱子个数,可以将其调整为512,1024,2048,4096,8192...
with ENABLE_CUDA_KERNEL():
	# 转换一个torch模型到onnx,并保存到指定位置
	dump_torch_to_onnx(model=model,onnx_export_file='Output/model.onnx',
		input_shape=INPUT_SHAPE,input_dtype=torch.float32)
	# 从一个指定位置加载onnx计算图,注意该加载的计算图尚未经过调度,此时所有算子被认为是可量化的
	graph=load_onnx_graph(onnx_import_file='Output/model.onnx')
	# 量化一个已经在内存中的ppq模型 输入一个量化前的PPQ.IR.BaseGraph 返回一个量化后的PPQ.IR.BaseGraph
	quantized=quantize_native_model(
		model=graph,calib_dataloader=CALIBRATION,
		calib_steps=32,input_shape=INPUT_SHAPE,
		collate_fn=collate_fn,platform=PLATFORM,
		device=DEVICE,verbose=0)
reports=graphwise_error_analyse(
	graph=quantized,running_device=DEVICE,collate_fn=collate_fn,
	dataloader=CALIBRATION)

代码注释中的箱子个数是什么之后再研究... 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值