Tensor Rt的int8量化原理

本文介绍了Tensor RT的INT8量化原理,旨在减少模型计算的精度损失,提高推理速度。主要内容包括线性量化和对称线性量化方法,探讨了如何通过优化阈值选择,特别是利用相对熵和KL散度进行校准,以最小化信息损失。同时,详细阐述了Tensor RT的量化流程,包括收集统计量、执行校准算法和生成INT8推理引擎。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

量化的目标

  • 把神经网络运算的32浮点表示的权重,变成8为的Int整数,并且希望没有显著的准确率下降
  • 为什么要采用In8,因为它可以带来更高吞吐率,并且更少内存占用
  • 但是也面临挑战,Int8更低的精度,并且有更小的动态范围
  • 如何保证量化后的准确率呢,解决方案 : 对Int8量化后的模型权重和激活函数,进行最小化信息损失。
  • Tensor RT采用的方法,不需要额外的fine tuning 或重新训练。

In8推理

挑战

  • INT8 相对于FP32具有较低的精度和动态范围
    在这里插入图片描述
  • 从表中可以看出32位浮点,16位浮点,INT8 的动态范围有很大的不同,比如16位点是-65504 ~ +65504
### PyTorch MNIST 模型 INT8 量化的实现 对于希望减少计算资源消耗并提高推理速度的应用场景而言,INT8量化是一种有效的方法。通过降低权重和激活值的精度8位整数表示,可以在保持较高准确性的同时显著提升性能。 在PyTorch中执行INT8量化涉及几个重要步骤: #### 准备环境与加载预训练模型 为了确保后续操作顺利进行,先安装必要的库,并导入所需的模块[^1]。 ```python import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader from tqdm import tqdm ``` 接着定义转换函数以及数据集加载器来获取MNIST测试集用于校准过程: ```python transform = transforms.Compose([ transforms.ToTensor(), ]) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) calibration_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) ``` #### 定义辅助类来进行统计收集 创建一个简单的观察者类以记录每一层的最大最小值范围,这对于后续确定合适的缩放因子至关重要: ```python class MinMaxObserver(torch.nn.Module): def __init__(self): super(MinMaxObserver, self).__init__() self.min_val = float('inf') self.max_val = -float('inf') def forward(self, x): min_x = torch.min(x).item() max_x = torch.max(x).item() if min_x < self.min_val: self.min_val = min_x if max_x > self.max_val: self.max_val = max_x return x ``` #### 应用量化感知训练 (QAT) 或静态量化方法 这里展示的是基于静态仿真的方式,在此之前需确保已有一个经过充分训练好的浮点版本模型实例 `model` 可供使用。应用量化配置前记得设置为评估模式: ```python model.eval() # Switch the model into evaluation mode. quantized_model = torch.quantization.convert(model.to('cpu'), inplace=False) ``` 如果采用动态量化,则只需指定哪些类型的层需要被处理;而对于静态量化来说,除了上述之外还需要额外提供代表性的输入样本以便于调整比例尺参数: ```python # Static Quantization Preparation fused_model = torch.quantization.fuse_modules( copy.deepcopy(model), [['conv1', 'relu1'], ['conv2', 'relu2']]) qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')} prepared_quantized_model = torch.quantization.prepare_qat(fused_model, qconfig_spec=qconfig_dict) for images, _ in calibration_loader: prepared_quantized_model(images) final_quantized_model = torch.quantization.convert(prepared_quantized_model.cpu().eval(), inplace=False) ``` 完成以上流程之后便得到了适用于部署阶段使用的低比特宽度网络结构——即完成了从FP32至INT8的数据类型转变工作[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

@BangBang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值