torch.fx 量化支持——FX GRAPH MODE QUANTIZATION
torch.fx目前支持的量化方式:
- Post Training Quantization
- Weight Only Quantization
- Dynamic Quantization
- Static Quantization
- Quantization Aware Training
- Static Quantization
其中,Post Training Quantization中的Static Quantization和Dynamic Quantization提供了demo。
与Eager模式对比
简单来说,fx提供一个Graph模式:
- 可以自动插入量化节点(如quantize和dequantize),不需要手动修改当前的network及forward
- 这个模式下可以看到forward是怎么被自动构建的,可以进行更精细的调整
Graph模式
局限:只有可以转换为符号的部分(symbolically traceable)可以被量化,Data dependent control flow是不支持的。如果模型有些部分无法被符号化,则量化只能在模型的部分上工作,不能被符号化的部分会被跳过。
如果需要这些部分被量化:
- 重写代码让这些部分symbolically traceable
- 将这些部分转换成observed和quantized的子模块
相关的具体操作见(PROTOTYPE) FX GRAPH MODE QUANTIZATION USER GUIDE。
训练后量化尝试
环境准备:
import torch
import copy
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
步骤
- 准备训练好的权重、数据及网络模型
- 初始化网络,加载训练好的权重(一般用copy.deepcopy保留原始模型),并将其置于eval模式:
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()
- 指定量化模型的qconfig_dict
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
qconfig是QConfig的一个实例,QConfig这个类就是维护了两个observer,一个是activation所使用的observer,一个是op权重所使用的observer。
backend | activation | weight |
---|---|---|
fbgemm (x86) | HistogramObserver (reduce_range=True) | PerChannelMinMaxObserver (default_per_channel_weight_observer) |
qnnpack (arm) | HistogramObserver (reduce_range=False) | MinMaxObserver (default_weight_observer) |
default | MinMaxObserver (default_observer) | MinMaxObserver (default_weight_observer) |
- 准备模型并打印模型:
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print(prepared_model.graph)
- 模型较准
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
calibrate(prepared_model, data_loader_test) # run calibration on sample data
- 量化模型
quantized_model = convert_fx(prepared_model)
print(quantized_model)
- 对比量化前后,评估量化效果,包括模型大小、性能、时延等