模型量化:8-bit/4-bit 量化、PTQ 与 QAT
1. 模型量化概述
模型量化(Quantization)是一种减少模型存储大小和计算复杂度的方法,通常用于嵌入式设备和边缘计算。量化的主要目标是用更低的精度(如 INT8 或 FP16)来表示模型权重和激活值,从而加速推理并降低存储需求。
2. 量化方法
2.1 8-bit/4-bit 量化(INT8、FP16)
-
FP16(半精度浮点数)
- 使用 16-bit 浮点数代替 32-bit 浮点数。
- 提供较小的模型存储需求,同时仍保持较高精度。
- 适用于 GPU 和支持 FP16 的硬件(如 NVIDIA TensorRT)。
-
INT8(8-bit 整数量化)
- 将 32-bit 浮点数转换为 8-bit 整数。
- 可显著提高计算速度,特别适用于 CPU 和加速器(如 ARM、Edge TPU)。
- 需要进行量化缩放(scale)和零点(zero-point)调整。
-
4-bit 量化(INT4)
- 进一步降低计算精度,用 4-bit 代替 8-bit。
- 适用于超低功耗设备,但可能会损失较多精度。
2.2 训练后量化(Post-Training Quantization,PTQ)
- 先训练全精度(FP32)模型。
- 训练后使用量化技术转换权重和激活值。
- 适用于计算能力有限的设备。
- 主要方法:
- 动态量化(Dynamic Quantization)
- 静态量化(Static Quantization)
- 整数(INT-only)量化
2.3 量化感知训练(Quantization-Aware Training,QAT)
- 在训练过程中模拟量化效果。
- 量化过程包含在前向传播和反向传播中,以减少量化误差。
- 适用于高精度需求的边缘 AI 设备,如高端嵌入式处理器。
3. PyTorch 量化优化案例:MobileNetV2 量化
3.1 任务概述
- 数据集:CIFAR-10
- 模型:MobileNetV2
- 目标:
- 使用 PTQ 进行 INT8 量化
- 使用 QAT 进行优化
3.2 加载预训练模型
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练的 MobileNetV2
model = models.mobilenet_v2(pretrained=True