1. A roundup of distillation approaches
2. Distillation tools
https://github.com/airaria/TextBrewer#quickstart
2.1 Distillation steps:
2.2 Method: the API looks fairly simple to use
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
# Show the statistics of model parameters
# (teacher_model and student_model are assumed to be defined already)
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)
print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)
# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of model outputs are the logits and hidden states
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}
# Training configuration
train_config = TrainingConfig()
# Distillation configuration
# Matching different layers of the student and the teacher
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])
# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)
# Start!
# (optimizer, dataloader, scheduler_class and scheduler_args are assumed to be defined already)
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                    callback=None)
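To make the intermediate_matches setting above more concrete, here is a rough plain-Python sketch of what a hidden-state MSE match computes. It is an illustration only, not TextBrewer's implementation: the helpers mse and intermediate_loss and the toy hidden states are hypothetical, real hidden states are tensors rather than short lists, and the sketch assumes teacher and student hidden vectors have the same size.

```python
# Illustrative only: plain-Python stand-ins, not TextBrewer's implementation.

def simple_adaptor(batch, model_outputs):
    # Same layout as the adaptor above: index 1 holds logits, index 2 holds hidden states.
    return {'logits': model_outputs[1], 'hidden': model_outputs[2]}

def mse(xs, ys):
    # Hypothetical helper: mean squared error between two equal-length vectors.
    assert len(xs) == len(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def intermediate_loss(teacher_out, student_out, matches):
    # Hypothetical helper: weighted sum of per-layer MSE terms, one per match entry.
    total = 0.0
    for m in matches:
        h_t = teacher_out[m['feature']][m['layer_T']]
        h_s = student_out[m['feature']][m['layer_S']]
        total += m['weight'] * mse(h_s, h_t)
    return total

# Toy hidden states: one small vector per layer (teacher has 9 layers, student has 3).
teacher_hidden = [[float(i), float(i) + 1.0] for i in range(9)]
student_hidden = [[0.0, 1.0], [3.0, 4.0], [7.0, 9.0]]

# The tuple (loss, logits, hidden_states) mimics the output layout the adaptor expects.
teacher_out = simple_adaptor(None, (None, [0.1, 0.9], teacher_hidden))
student_out = simple_adaptor(None, (None, [0.2, 0.8], student_hidden))

matches = [
    {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
    {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
]

loss = intermediate_loss(teacher_out, student_out, matches)
print(loss)
```

The first match contributes nothing because the two layer-0 vectors are identical; only the mismatch between teacher layer 8 and student layer 2 is penalized. Each entry pairs one teacher layer with one student layer, so a shallow student can be supervised by selected layers of a deeper teacher.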
3. References
A Complete Guide to BERT Distillation | Principles / Tricks / Code (BERT蒸馏完全指南|原理/技巧/代码): https://mp.weixin.qq.com/s/p0EZ4uFrLBLUuRRNMezGiQ