Understanding FastBERT: a Self-distilling BERT with Adaptive Inference Time
Preface
The FastBERT paper is at https://arxiv.org/pdf/2004.02178.pdf, and the open-source code is at https://github.com/BitVoyage/FastBERT.
To reduce BERT's inference time, the authors propose FastBERT, a model built on self-distillation and an adaptive inference mechanism. The paper demonstrates the idea on text classification, in three stages:

(1) Fine-tuning: a classifier head (the Teacher Classifier) is attached to the BERT model and the model is fine-tuned for classification. The Classifier's structure is: fully-connected layer (768->128) -> BERT self-attention (128->128) -> fully-connected layer (128->128) -> fully-connected layer (128->num_class);

(2) Self-distillation: a Student Classifier is attached after every layer of the BERT encoder; the parameters from the fine-tuning stage are frozen, and each Student Classifier learns to match the Teacher Classifier's output distribution, with KL divergence as the loss;

(3) Inference: at every layer, the Student Classifier makes a prediction whose uncertainty is measured by entropy (higher entropy means higher uncertainty). A threshold called speed controls the trade-off: the larger the speed threshold, the earlier samples can exit and the faster inference runs. Easy samples may be resolved after only a few layers, while in the worst case a sample passes through all layers, so inference saves computation on average (see the sketch below).
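To make the early-exit mechanism concrete, here is a minimal, self-contained sketch of the adaptive inference loop; it is not the repo's code, and the linear layers are stand-ins for the Transformer blocks and student classifiers:

```python
import torch
import torch.nn.functional as F

num_class = 2
speed = 0.5  # uncertainty threshold: larger => more early exits => faster inference

# Stand-ins for 12 encoder layers and their per-layer student classifiers
layers = [torch.nn.Linear(16, 16) for _ in range(12)]
classifiers = [torch.nn.Linear(16, num_class) for _ in range(12)]

hidden = torch.randn(1, 16)  # a single sample: early exit is decided per sample
for i, (layer, clf) in enumerate(zip(layers, classifiers)):
    hidden = layer(hidden)
    logits = clf(hidden)
    prob = F.softmax(logits, dim=-1)
    log_prob = F.log_softmax(logits, dim=-1)
    # Normalized entropy in [0, 1]: 0 = fully confident, 1 = uniform
    uncertain = torch.sum(prob * log_prob, dim=-1) / (-torch.log(torch.tensor(float(num_class))))
    if uncertain.item() < speed:  # confident enough: stop here
        break
print(f"exited after layer {i}, prob={prob.detach()}")
```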
1. The FastBERT Model
> FastBERT, a pre-trained model with a sample-wise adaptive mechanism. It can adjust the number of executed layers dynamically to reduce computational steps. This model also has a unique self-distillation process that requires minimal changes to the structure, achieving faster yet as accurate outcomes within a single framework.
In other words, the authors propose a pre-trained model with a sample-wise adaptive mechanism that dynamically adjusts the number of executed layers to cut computation. The self-distillation happens within a single framework, requires only minimal changes to the model structure, and yields faster inference with comparable accuracy.
Backbone: the embedding layer, the encoder containing stacks of Transformer blocks, and the teacher classifier (i.e., the BERT model plus the Teacher Classifier layer).
Branches: all of the Student Classifier layers.
Structure and implementation of the Classifier layer:
```python
class FastBERTClassifier(nn.Module):
    def __init__(self, config, op_config):
        super(FastBERTClassifier, self).__init__()
        cls_hidden_size = op_config["cls_hidden_size"]
        num_attention_heads = op_config['cls_num_attention_heads']
        num_class = op_config["num_class"]
        # Narrow the 768-dim hidden states down to the classifier width (e.g. 128)
        self.dense_narrow = nn.Linear(config.hidden_size, cls_hidden_size)
        self.selfAttention = BERTSelfAttention(config, hidden_size=cls_hidden_size, num_attention_heads=num_attention_heads)
        self.dense_prelogits = nn.Linear(cls_hidden_size, cls_hidden_size)
        self.dense_logits = nn.Linear(cls_hidden_size, num_class)

    def forward(self, hidden_states):
        states_output = self.dense_narrow(hidden_states)
        states_output = self.selfAttention(states_output, None, use_attention_mask=False)
        # Use the [CLS] position as the sequence representation
        token_cls_output = states_output[:, 0]
        prelogits = self.dense_prelogits(token_cls_output)
        logits = self.dense_logits(prelogits)
        return logits
```
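A rough usage sketch of this classifier (the shapes and op_config values here are illustrative assumptions; bert_config is the repo's BERT config with hidden_size=768):

```python
# Illustrative only: a batch of 2 sequences of length 8, hidden size 768
op_config = {"cls_hidden_size": 128, "cls_num_attention_heads": 2, "num_class": 2}
classifier = FastBERTClassifier(bert_config, op_config)

hidden_states = torch.randn(2, 8, bert_config.hidden_size)
logits = classifier(hidden_states)  # shape (2, 2): one logit row per sample
```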
Implementation of FastBERTGraph, which contains the core logic for all three phases: fine-tuning, self-distillation, and inference:
```python
class FastBERTGraph(nn.Module):
    def __init__(self, bert_config, op_config):
        super(FastBERTGraph, self).__init__()
        bert_layer = BERTLayer(bert_config)
        self.layers = nn.ModuleList([copy.deepcopy(bert_layer) for _ in range(bert_config.num_hidden_layers)])

        self.layer_classifier = FastBERTClassifier(bert_config, op_config)
        # Dictionary of classifier layers: one student per encoder layer, plus the final teacher
        self.layer_classifiers = nn.ModuleDict()
        for i in range(bert_config.num_hidden_layers - 1):
            self.layer_classifiers['branch_classifier_' + str(i)] = copy.deepcopy(self.layer_classifier)
        self.layer_classifiers['final_classifier'] = copy.deepcopy(self.layer_classifier)

        self.ce_loss_fct = nn.CrossEntropyLoss()  # cross-entropy loss for the fine-tuning stage
        self.num_class = torch.tensor(op_config["num_class"], dtype=torch.float32)

    def forward(self, hidden_states, attention_mask, labels=None, inference=False, inference_speed=0.5, training_stage=0):
        #-----Inference: return early once a student's uncertainty at layer i is low enough-----#
        if inference:
            uncertain_infos = []
            for i, (layer_module, (k, layer_classifier_module)) in enumerate(zip(self.layers, self.layer_classifiers.items())):
                hidden_states = layer_module(hidden_states, attention_mask)
                logits = layer_classifier_module(hidden_states)
                prob = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
                uncertain_infos.append([uncertain, prob])
                # Early exit once the uncertainty drops below the threshold
                if uncertain < inference_speed:
                    return prob, i, uncertain_infos
            return prob, i, uncertain_infos
        #-----Training: stage 0 is initial fine-tuning, stage 1 is self-distillation-----#
        else:
            # Initial training, identical to ordinary fine-tuning
            if training_stage == 0:
                for layer_module in self.layers:
                    hidden_states = layer_module(hidden_states, attention_mask)
                logits = self.layer_classifier(hidden_states)
                loss = self.ce_loss_fct(logits, labels)
                return loss, logits
            # Distillation: the loss is the KL divergence between each student and the teacher
            else:
                all_encoder_layers = []
                for layer_module in self.layers:
                    hidden_states = layer_module(hidden_states, attention_mask)
                    all_encoder_layers.append(hidden_states)

                all_logits = []
                for encoder_layer, (k, layer_classifier_module) in zip(all_encoder_layers, self.layer_classifiers.items()):
                    layer_logits = layer_classifier_module(encoder_layer)
                    all_logits.append(layer_logits)

                # NOTE: debug whether frozen
                # print(self.layer_classifiers['final_classifier'].dense_narrow.weight)
                loss = 0.0
                teacher_log_prob = F.log_softmax(all_logits[-1], dim=-1)
                for student_logits in all_logits[:-1]:
                    student_prob = F.softmax(student_logits, dim=-1)
                    student_log_prob = F.log_softmax(student_logits, dim=-1)
                    uncertain = torch.sum(student_prob * student_log_prob, 1) / (-torch.log(self.num_class))
                    # print('uncertain:', uncertain[0])
                    D_kl = torch.sum(student_prob * (student_log_prob - teacher_log_prob), 1)
                    D_kl = torch.mean(D_kl)
                    loss += D_kl
                return loss, all_logits
```
1. Initial training stage:
```python
for layer_module in self.layers:
    hidden_states = layer_module(hidden_states, attention_mask)
logits = self.layer_classifier(hidden_states)
loss = self.ce_loss_fct(logits, labels)
return loss, logits
```
This stage is plain BERT fine-tuning with a classifier head on top, identical to ordinary model fine-tuning.
2. Self-distillation stage
```python
all_encoder_layers = []  # hidden states of every encoder layer
for layer_module in self.layers:
    hidden_states = layer_module(hidden_states, attention_mask)
    all_encoder_layers.append(hidden_states)

all_logits = []  # logits of every classifier layer
for encoder_layer, (k, layer_classifier_module) in zip(all_encoder_layers, self.layer_classifiers.items()):
    layer_logits = layer_classifier_module(encoder_layer)
    all_logits.append(layer_logits)

# NOTE: debug whether frozen
# print(self.layer_classifiers['final_classifier'].dense_narrow.weight)
loss = 0.0
# Log-probabilities of the Teacher Classifier (the last classifier)
teacher_log_prob = F.log_softmax(all_logits[-1], dim=-1)
for student_logits in all_logits[:-1]:
    student_prob = F.softmax(student_logits, dim=-1)
    student_log_prob = F.log_softmax(student_logits, dim=-1)
    # Uncertainty (normalized entropy)
    uncertain = torch.sum(student_prob * student_log_prob, 1) / (-torch.log(self.num_class))
    # print('uncertain:', uncertain[0])
    # KL divergence between the student and teacher distributions
    D_kl = torch.sum(student_prob * (student_log_prob - teacher_log_prob), 1)
    D_kl = torch.mean(D_kl)
    loss += D_kl
return loss, all_logits
```
The backbone parameters are frozen, so that only the Student Classifiers learn to match the teacher's distribution:
```python
model = FastBertModel(bert_config, config)
load_saved_model(model, args.save_model_path)
save_model_path_for_train = args.save_model_path_distill

# Freeze the parameters from the fine-tuning stage; only branch (student) classifiers stay trainable
for name, p in model.named_parameters():
    if "branch_classifier" not in name:
        p.requires_grad = False
```
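A quick sanity check (my own sketch, not from the repo) that the freeze worked as intended:

```python
# After freezing, every remaining trainable parameter should belong to a branch classifier
for name, p in model.named_parameters():
    if p.requires_grad:
        assert "branch_classifier" in name, f"unexpected trainable parameter: {name}"
print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")
```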
The KL divergence used as the distillation loss, with $p_s$ the student distribution and $p_t$ the teacher distribution over $N$ classes:

$$D_{KL}(p_s \,\|\, p_t) = \sum_{i=1}^{N} p_s(i)\,\log\frac{p_s(i)}{p_t(i)}$$
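The repo computes this divergence by hand. As a self-contained check (my own sketch, not repo code), the hand-rolled version matches PyTorch's built-in F.kl_div when the student distribution is passed as the target:

```python
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 3)  # arbitrary batch of 4 samples, 3 classes
teacher_logits = torch.randn(4, 3)

p_s = F.softmax(student_logits, dim=-1)
log_p_s = F.log_softmax(student_logits, dim=-1)
log_p_t = F.log_softmax(teacher_logits, dim=-1)

# Hand-rolled, as in FastBERTGraph: KL(student || teacher), averaged over the batch
manual = torch.sum(p_s * (log_p_s - log_p_t), dim=1).mean()
# Built-in equivalent
builtin = F.kl_div(log_p_t, p_s, reduction='batchmean')
print(torch.allclose(manual, builtin))  # True
```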
3. Inference stage code:
```python
uncertain_infos = []
for i, (layer_module, (k, layer_classifier_module)) in enumerate(zip(self.layers, self.layer_classifiers.items())):
    hidden_states = layer_module(hidden_states, attention_mask)
    logits = layer_classifier_module(hidden_states)
    # Output probabilities of this layer's classifier
    prob = F.softmax(logits, dim=-1)
    log_prob = F.log_softmax(logits, dim=-1)  # log-probabilities
    # Uncertainty (normalized entropy)
    uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
    uncertain_infos.append([uncertain, prob])
    # Early exit when the uncertainty is below the threshold (assumes batch size 1)
    if uncertain < inference_speed:
        return prob, i, uncertain_infos
return prob, i, uncertain_infos
```
The uncertainty formula, where $p_s$ denotes the output probability distribution and $N$ is the number of class labels:

$$\text{Uncertainty} = \frac{\sum_{i=1}^{N} p_s(i)\,\log p_s(i)}{-\log N} = \frac{H(p_s)}{\log N}$$
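This is just entropy normalized into [0, 1]: a uniform distribution (maximally uncertain) scores 1, a near-one-hot distribution scores close to 0. A tiny numeric check (my own sketch, not repo code):

```python
import torch

def uncertainty(prob):
    # Normalized entropy, computed the same way as in FastBERTGraph
    n = torch.tensor(float(prob.numel()))
    return torch.sum(prob * torch.log(prob)) / (-torch.log(n))

print(uncertainty(torch.tensor([0.5, 0.5])))    # tensor(1.) -> never exits early unless speed > 1
print(uncertainty(torch.tensor([0.99, 0.01])))  # ~0.08 -> exits early for most speed settings
```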
2. Training the FastBERT Model
The implementation follows the GitHub repo: https://github.com/BitVoyage/FastBERT
Training logic:
1. Load the pre-trained BERT and fine-tune the classification model.
2. Freeze the backbone and the final teacher classifier; fit each layer's student classifier to the teacher (with KL divergence as the loss).
3. At inference time, return early whenever a student classifier is sufficiently confident about the input sample.
Code for selecting the training stage:
```python
if args.run_mode == 'train':
    # Stage 0: initial training (fine-tuning)
    if args.train_stage == 0:
        model = FastBertModel.load_pretrained_bert_model(bert_config, config,
                        pretrained_model_path=config.get("bert_pretrained_model_path"))
        save_model_path_for_train = args.save_model_path
    # Stage 1: distillation training
    elif args.train_stage == 1:
        model = FastBertModel(bert_config, config)
        load_saved_model(model, args.save_model_path)
        save_model_path_for_train = args.save_model_path_distill

        # Freeze the backbone parameters trained during the fine-tuning stage
        for name, p in model.named_parameters():
            if "branch_classifier" not in name:
                p.requires_grad = False
        logging.info("Main Graph and Teacher Classifier Freezed, Student Classifier will Distilling")
    else:
        raise RuntimeError('Operation Train Stage(0 or 1) not Legal')
elif args.run_mode == 'eval':
    model = FastBertModel(bert_config, config)
    load_saved_model(model, args.save_model_path)
else:
    raise RuntimeError('Operation Mode not Legal')
```
Testing workflow:
1. Initial training:
```
sh run_scripts/script_train_stage0.sh
```
2. Distillation training:
```
sh run_scripts/script_train_stage1.sh
```
**Note**: the distillation stage takes unlabeled data as input, so additional data can be brought in as needed to improve robustness.
3. Inference:
```
sh run_scripts/script_infer.sh
```
where the inference_speed parameter (0.0~1.0) controls the degree of acceleration.
4. Deployment:
```
python3 predict.py
```
Results measured from the source code (test device: a single GPU).

Inference on the test set, ChnSentiCorp:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0364 | 0.9533 | baseline |
| 0.1 | 0.0330 | 0.9533 | 1.10x |
| 0.5 | 0.0223 | 0.9383 | 1.63x |
| 0.8 | 0.0171 | 0.9108 | 2.13x |
Inference on the dev set, ChnSentiCorp:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0365 | 0.9392 | baseline |
| 0.1 | 0.0332 | 0.9400 | 1.10x |
| 0.5 | 0.0237 | 0.9333 | 1.54x |
| 0.8 | 0.0176 | 0.9100 | 2.07x |
Same device, with a model retrained from scratch following the same steps; inference on the dev set, ChnSentiCorp:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0356 | 0.9425 | baseline |
| 0.1 | 0.0263 | 0.9433 | 1.35x |
| 0.5 | 0.0186 | 0.9383 | 1.91x |
| 0.8 | 0.0157 | 0.9050 | 2.26x |
The accuracies closely match the repo author's numbers, but the inference speeds differ considerably, which is probably hardware-related.
Test results published on GitHub, ChnSentiCorp:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.14725032741416672 | 0.9400 | baseline |
| 0.1 | 0.10302954971909761 | 0.9420 | 1.42x |
| 0.5 | 0.03420266199111938 | 0.9340 | 4.29x |
| 0.8 | 0.019530397139952513 | 0.9160 | 7.54x |
Note: at speed=0.1 the accuracy is actually higher than the baseline. This is plausible; early exits can act somewhat like regularization.

From the repo author's measurements: after distillation, speed=0.1 yields a 1.42x speedup with a slight gain in accuracy, while speed=0.5 yields a 4.29x speedup with a 0.006 drop in accuracy. As speed increases, inference gets faster and accuracy falls.
3. Test Results from the FastBERT Paper
The paper evaluates on six Chinese and six English datasets. The results show that at speed=0.1, FastBERT achieves a 2~5x inference speedup with essentially unchanged accuracy; allowing a small accuracy loss, FastBERT is 7 to 11 times faster than BERT.
My own results on several datasets, using the GitHub code with unmodified parameters, measured on a GTX 1080 Ti GPU:
ChnSentiCorp dataset:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0286 | 0.9425 | baseline |
| 0.1 | 0.0286 | 0.9433 | 1.0x |
| 0.5 | 0.0178 | 0.9383 | 1.6x |
| 0.8 | 0.0136 | 0.9050 | 2.1x |
| 1.0 | 0.0049 | 0.5025 | 5.8x |
thnews dataset:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0334 | 0.9704 | baseline |
| 0.1 | 0.0303 | 0.9704 | 1.1x |
| 0.5 | 0.0268 | 0.9704 | 1.2x |
| 0.8 | 0.0126 | 0.9696 | 2.7x |
| 1.0 | 0.0070 | 0.8930 | 4.8x |
Weibo dataset:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0208 | 0.9788 | baseline |
| 0.1 | 0.0280 | 0.9788 | 0.7x |
| 0.5 | 0.0065 | 0.9783 | 3.2x |
| 0.8 | 0.0058 | 0.9774 | 3.6x |
| 1.0 | 0.0045 | 0.9680 | 6.2x |
Douban dataset:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0283 | 0.8796 | baseline |
| 0.1 | 0.0240 | 0.8808 | 1.2x |
| 0.5 | 0.0144 | 0.8693 | 2.0x |
| 1.0 | 0.0045 | 0.7100 | 6.3x |
Online shopping dataset:

| speed_arg | time_per_record | acc | speedup |
| --- | --- | --- | --- |
| 0.0 | 0.0299 | 0.9697 | baseline |
| 0.1 | 0.0157 | 0.9694 | 1.9x |
| 0.5 | 0.0071 | 0.9602 | 4.2x |
| 1.0 | 0.0045 | 0.9202 | 6.6x |
These results show that the post-distillation speedup varies considerably across the Chinese datasets, and that repeated runs on the same dataset also show some fluctuation in inference speed, likely depending on the state of the machine.

The test datasets can be downloaded from https://share.weiyun.com/ZctQJP8h or from Baidu Netdisk: https://pan.baidu.com/s/1MayPlDKJ0_UEZgrdp_1MLg (extraction code: x6fw).
Summary

Across multiple datasets, FastBERT delivers a clear inference speedup at the cost of only a small loss in accuracy.
References

Zhihu article: https://zhuanlan.zhihu.com/p/127869267
GitHub: https://github.com/autoliuweijie/FastBERT