FasterTransformer中的BERT模型量化实践指南
概述
本文将详细介绍如何在FasterTransformer框架下对BERT模型进行量化处理,包括训练后量化(PTQ)和量化感知训练(QAT)两种主要方法。通过量化技术,可以显著减少模型大小并提高推理速度,同时尽可能保持模型精度。
环境准备
硬件要求
- 推荐使用4块Tesla V100-SXM2-16GB显卡
- 显存时钟频率设置为877MHz
- 处理器时钟频率设置为1530MHz
软件环境
- 使用TensorFlow 1.15.2版本
- 需要安装FasterTransformer的量化组件
环境设置命令:
pip install ft-tensorflow-quantization/
export TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE=0
数据准备
下载预训练BERT模型
wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip -O uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip -d squad_model
下载SQuAD数据集
mkdir squad_data
wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
wget -P squad_data https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
高精度模型微调
在进行量化前,我们需要先对原始BERT模型进行微调,以获得一个高精度的基准模型。
执行命令:
mpirun -np 4 -H localhost:4 \
--allow-run-as-root -bind-to none -map-by slot \
-x NCCL_DEBUG=INFO \
-x LD_LIBRARY_PATH \
-x PATH -mca pml ob1 -mca btl ^openib \
python run_squad.py \
--bert_config_file=squad_model/bert_config.json \
--vocab_file=squad_model/vocab.txt \
--train_file=squad_data/train-v1.1.json \
--predict_file=squad_data/dev-v1.1.json \
--init_checkpoint=squad_model/bert_model.ckpt \
--output_dir=squad_model/finetuned_base \
--do_train=True \
--do_predict=True \
--if_quant=False \
--train_batch_size=8 \
--learning_rate=1e-5 \
--num_train_epochs=2.0 \
--save_checkpoints_steps 1000 \
--horovod
评估结果通常能达到:
{"exact_match": 82.44, "f1": 89.57}
训练后量化(PTQ)
PTQ是在模型训练完成后进行的量化过程,不需要重新训练模型。
量化模式说明
FasterTransformer支持三种量化模式:
- ft1:轻量级量化,对部分层进行量化
- ft2:中等量化,平衡精度和性能
- ft3:全面量化,最大限度提升性能
执行PTQ
python run_squad.py \
--bert_config_file=squad_model/bert_config.json \
--vocab_file=squad_model/vocab.txt \
--train_file=squad_data/train-v1.1.json \
--predict_file=squad_data/dev-v1.1.json \
--init_checkpoint=squad_model/finetuned_base/model.ckpt-5474 \
--output_dir=squad_model/PTQ_mode_2 \
--do_train=False \
--do_predict=True \
--do_calib=True \
--if_quant=True \
--train_batch_size=16 \
--calib_batch=16 \
--calib_method=percentile \
--percentile=99.999 \
--quant_mode=ft2
典型结果:
- ft1模式:{"exact_match": 81.67, "f1": 88.94}
- ft2模式:{"exact_match": 80.44, "f1": 88.30}
量化感知训练(QAT)
如果PTQ结果不理想,可以采用QAT方法,在训练过程中模拟量化效果。
校准阶段
python run_squad.py \
--bert_config_file=squad_model/bert_config.json \
--vocab_file=squad_model/vocab.txt \
--train_file=squad_data/train-v1.1.json \
--predict_file=squad_data/dev-v1.1.json \
--init_checkpoint=squad_model/bert_model.ckpt \
--output_dir=squad_model/QAT_calibrated_mode_2 \
--do_train=False \
--do_calib=True \
--train_batch_size=16 \
--calib_batch=16 \
--calib_method=percentile \
--percentile=99.99 \
--quant_mode=ft2
训练阶段
mpirun -np 4 -H localhost:4 \
--allow-run-as-root -bind-to none -map-by slot \
-x NCCL_DEBUG=INFO \
-x LD_LIBRARY_PATH \
-x PATH -mca pml ob1 -mca btl ^openib \
python run_squad.py \
--bert_config_file=squad_model/bert_config.json \
--vocab_file=squad_model/vocab.txt \
--train_file=squad_data/train-v1.1.json \
--predict_file=squad_data/dev-v1.1.json \
--init_checkpoint=squad_model/QAT_calibrated_mode_2/model.ckpt-calibrated \
--output_dir=squad_model/QAT_mode_2 \
--do_train=True \
--do_predict=True \
--if_quant=True \
--train_batch_size=8 \
--learning_rate=1e-5 \
--num_train_epochs=2.0 \
--save_checkpoints_steps 1000 \
--quant_mode=ft2 \
--horovod
典型结果:
- ft1模式:{"exact_match": 82.11, "f1": 89.39}
- ft2模式:{"exact_match": 81.74, "f1": 89.12}
结合知识蒸馏的QAT
知识蒸馏可以进一步提升量化模型的性能,通常从PTQ检查点开始。
mpirun -np 4 -H localhost:4 \
--allow-run-as-root -bind-to none -map-by slot \
-x NCCL_DEBUG=INFO \
-x LD_LIBRARY_PATH \
-x PATH -mca pml ob1 -mca btl ^openib \
python run_squad.py \
--bert_config_file=squad_model/bert_config.json \
--vocab_file=squad_model/vocab.txt \
--train_file=squad_data/train-v1.1.json \
--predict_file=squad_data/dev-v1.1.json \
--init_checkpoint=squad_model/PTQ_mode_2_for_KD/model.ckpt-calibrated \
--output_dir=squad_model/QAT_KD_mode_2 \
--do_train=True \
--do_predict=True \
--if_quant=True \
--train_batch_size=8 \
--learning_rate=2e-5 \
--num_train_epochs=10.0 \
--save_checkpoints_steps 1000 \
--quant_mode=ft2 \
--horovod \
--distillation=True \
--teacher=squad_model/finetuned_base/model.ckpt-5474
典型结果:
- ft1模式:{"exact_match": 84.06, "f1": 90.63}
- ft2模式:{"exact_match": 84.02, "f1": 90.56}
总结
本文详细介绍了在FasterTransformer框架下对BERT模型进行量化的完整流程。从实验结果可以看出:
- 单纯的PTQ会导致约1-2%的精度下降
- QAT可以部分恢复精度损失
- 结合知识蒸馏的QAT不仅能恢复精度损失,甚至可能超过原始模型的性能
量化技术的选择应根据实际应用场景在模型大小、推理速度和精度之间进行权衡。对于大多数应用场景,ft2模式提供了良好的平衡点。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考