以下内容将全面剖析大模型的知识蒸馏(Distillation)技术,包括其基本原理、常见方法、在大规模模型(如 GPT/BERT/LLM)上的应用场景与挑战,以及可运行的 PyTorch 示例。知识蒸馏是近年来模型压缩的重要手段之一,对大语言模型的实际部署有显著帮助。
一、什么是知识蒸馏(Distillation)
知识蒸馏(Knowledge Distillation) 是指通过训练一个规模更小但性能不下降太多的学生模型(Student),来模仿或逼近一个规模更大、精度更高的教师模型(Teacher)的行为,从而达到模型压缩、加速推理的目的。
具体而言:
- 我们有一个已经训练好的 大模型(Teacher)。
- 我们想得到一个 小模型(Student),具备接近 Teacher 的性能但拥有更少的参数量和更快的推理速度。
- 在训练 Student 的过程中,不仅使用传统的标签/真值 (Ground Truth) 监督,还使用来自 Teacher 的输出(如 logits、embedding、注意力分布等)作为“软标签 (Soft Targets)”或额外监督信号,让学生更好地学习 Teacher 的知识。
二、知识蒸馏的基本原理与流程
-
教师模型(Teacher Model)
- 通常是一个在大规模数据上训练好的高精度模型;比如 GPT-2 Large、BERT-Large、LLaMA-2-70B 等。
-
学生模型(Student Model)
- 结构相似或简化版本:如把 BERT-Large 减少层数得到 BERT-Small,或减少隐藏单元/注意力头得到 TinyGPT。
- 也可能是另一种模型结构,但常见做法是保留类似的骨干(Backbone),以便能够尽可能保留原模型优势。
-
蒸馏目标函数
4. 训练过程
- 准备与教师一致的训练/微调数据,或者使用无标注的大规模文本来做蒸馏(称之为自监督蒸馏)。
- 前向:将输入送入教师和学生,得到对应输出(logits、内部层表示等)。
- 计算蒸馏损失
- 反向传播,更新学生模型参数。教师模型通常是冻结的(不训练)。
在大模型场景下,蒸馏最典型的例子是 DistilBERT:将一个 12 层的 BERT-base 蒸馏成一个 6 层的学生模型。它能在保留 97% 性能的同时,模型大小缩小约 40%,推理速度提升 60%。
三、大模型知识蒸馏的常见方法
1. Logits 蒸馏 (Logits Distillation)
最原始也是最常见的蒸馏方式,只关心模型输出层(logits)的模仿。
- Soft label:教师的输出分布(softmax 后的概率分布)提供了比真实标签更丰富的信息,例如哪几个类别更相似。
2. 特征表示蒸馏 (Feature Map Distillation)
在 Transformer 中,不只蒸馏最后的 logits,还要蒸馏中间层的隐状态(Hidden States)或注意力矩阵 (Attention Maps),例如:
- Hidden State Matching:在对应层对齐学生和教师的 hidden states,通过最小化 MSE 或 L1 或 cos-sim 来保留语义信息。
- Attention Head Distillation:在对应层让学生的注意力分布模仿教师的注意力分布。
这种方法能够让学生在层间细节上更接近教师,适合结构相似的模型(如小号 BERT <-> BERT-base)。
3. 多任务蒸馏 (Multi-Task Distillation)
当教师模型是一个多任务或大规模预训练模型时,可以在多个任务数据/多语言数据上进行联合蒸馏,让学生继承教师在不同任务/语言上的知识。
4. Progressive Distillation / Layer-wise Distillation
若学生层数远少于教师层数,则可采用分层逐步蒸馏的策略:
- 先对齐学生第 1 层和教师第 2 层,再对齐学生第 2 层与教师第 4 层…
- 或者预先划分映射关系,然后依次蒸馏;可以让学生更加稳定地学到教师的表征。
5. Prompt Distillation
大模型在指令微调或对话场景中,也可把教师的回答作为一个“软目标”,让学生学习如何在相同指令下进行回答,使学生具备类似的对话能力,但规模更小。
四、大模型知识蒸馏的主要挑战
- 模型规模与显存:教师本身可能非常大(上百亿参数)。对齐中间表示需要显著的 GPU 资源来同时运行 Teacher + Student。
- 训练数据:需要足够多且覆盖丰富领域的数据,才能让学生获得与教师类似的泛化能力。
- 结构差异:当 Student 的结构与 Teacher 出入较大(如层数相差太多),Feature 蒸馏的层对齐会比较复杂。
- 蒸馏时间:有时为了让学生逼近教师,需要花费大量蒸馏训练步骤。
- 多头注意力蒸馏:如果学生减少了头数,如何对齐与教师的注意力分布是个设计难点。
五、大模型知识蒸馏的应用场景
- 服务器端推理加速:如将 GPT-3.5 / LLaMA-2-70B 蒸馏到一个 7B 或 13B 规模的学生,显著降低推理延迟。
- 边缘/移动端部署:在算力、内存有限的设备上,需要一个小而精的模型。教师可以是 BERT-base,学生可以是 MobileBERT 或 TinyBERT。
- 多语言或多任务:大模型往往在多语言、多任务都有良好表现,学生可以通过蒸馏继承这个多任务能力。
- 知识迁移:教师在私有数据或增量数据上训练后,通过蒸馏的方式把新知识迁移到学生模型中,避免学生从零开始大规模训练。
六、可运行示例:蒸馏一个小型 GPT2 模型
下面示例演示 Logits 蒸馏 的流程:
- 教师模型:
gpt2-medium
- 学生模型:
gpt2
(或者一个自定义的小 GPT2 variant) - 任务:对语言建模 (Language Modeling) 数据进行蒸馏,让学生尽量学到教师的分布。
环境准备
- Python 3.7+
pip install torch transformers datasets accelerate
(保证与硬件/系统兼容)
6.1 加载模型、数据
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
from datasets import load_dataset
# 设置模型名称
teacher_model_name = "gpt2-medium"
student_model_name = "gpt2" # 也可自定义更小的模型结构
# 加载教师和学生模型
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = GPT2LMHeadModel.from_pretrained(teacher_model_name)
teacher_model.eval() # 推理/评估模式
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
student_model = GPT2LMHeadModel.from_pretrained(student_model_name)
# 简单数据集:使用 wikitext-2 做语言建模蒸馏演示
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split='train')
6.2 构建 DataLoader
def tokenize_fn(examples):
return teacher_tokenizer(examples["text"], truncation=True, max_length=128)
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
# PyTorch DataLoader
def collate_fn(batch):
# 这里直接使用 teacher_tokenizer 的 pad 方法,也可用 student_tokenizer
return teacher_tokenizer.pad(batch, return_tensors="pt")
from torch.utils.data import DataLoader
train_loader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
6.3 定义蒸馏损失
我们做最简单的logits蒸馏:
- 取教师模型的 logits (teacher_logits),学生模型的 logits (student_logits);
- 用 KL 散度或 MSE 让学生去模仿教师 logits 分布。
另外,还可以加一个常规的语言模型 (LM) 交叉熵损失,帮助学生不偏离真实任务。
import torch.nn.functional as F
def distillation_loss_function(teacher_logits, student_logits,
labels,
alpha=0.5, temperature=2.0):
"""
teacher_logits, student_logits: (batch_size, seq_len, vocab_size)
labels: (batch_size, seq_len)
alpha: 权重,平衡真实任务损失 与 蒸馏损失
temperature: 蒸馏温度
返回: total_loss
"""
# 1) LM 真实标签交叉熵
# 让学生在真实标签上也保持一定的准确度
# -100 表示填充位置不计算loss
lm_loss = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
labels.view(-1),
ignore_index=-100
)
# 2) 蒸馏损失 (KL 散度)
# 对 teacher / student 的 logits 做 softmax with temperature
# p(t) = softmax(teacher_logits / T)
# q(s) = softmax(student_logits / T)
teacher_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
student_probs = F.log_softmax(student_logits / temperature, dim=-1)
distill_loss = F.kl_div(
student_probs,
teacher_probs.exp(), # kl_div 需要 target 是概率分布 (非 log)
reduction='batchmean'
) * (temperature**2)
total_loss = alpha * lm_loss + (1 - alpha) * distill_loss
return total_loss, lm_loss, distill_loss
6.4 训练循环
import torch.optim as optim
# 冻结教师模型,不参与训练
for param in teacher_model.parameters():
param.requires_grad = False
optimizer = optim.AdamW(student_model.parameters(), lr=1e-5)
num_epochs = 1 # 简单跑1轮演示
alpha = 0.5
temperature = 2.0
student_model.train()
for epoch in range(num_epochs):
total_loss_val = 0.0
for step, batch in enumerate(train_loader):
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
with torch.no_grad():
teacher_out = teacher_model(input_ids, attention_mask=attention_mask)
teacher_logits = teacher_out.logits
student_out = student_model(input_ids, attention_mask=attention_mask)
student_logits = student_out.logits
# labels 用来计算学生的 LM 任务损失
labels = input_ids.clone()
# 也可以把 padding位置设为 -100
labels[labels==teacher_tokenizer.pad_token_id] = -100
loss, lm_loss, distill_loss = distillation_loss_function(
teacher_logits, student_logits,
labels,
alpha=alpha, temperature=temperature
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss_val += loss.item()
if step % 100 == 0:
print(f"Epoch {epoch}, Step {step}, Loss {loss.item():.4f}, LM {lm_loss.item():.4f}, KD {distill_loss.item():.4f}")
avg_loss = total_loss_val / (step+1)
print(f"Epoch {epoch} finished, avg loss = {avg_loss:.4f}")
以上仅是非常简化的示例:
- 未使用加速/混合精度/分布式并行等;
- 数据集只跑了一小批量;
- 蒸馏超参 (α,T\alpha, T) 需调参。
在真实场景中,需要更大的数据、更长的训练和更充分的超参搜索。
6.5 评估与使用
在结束训练后,即得到经过蒸馏的 student_model
。可对比原始 gpt2
或 gpt2-medium
在验证集上的困惑度 (Perplexity) 或其他指标,并观察模型大小、推理速度提升等。
七、Distillation 的扩展与实践经验
- 与剪枝/量化的结合
- 先剪枝再蒸馏;或蒸馏后量化;或多种技巧并用。可进一步降低模型复杂度。
- 分层蒸馏
- 如果教师比学生多 2 倍的层数,可映射每个学生层到教师的某个(或多个)层。对齐隐藏状态/注意力分布来蒸馏,会获得更好效果。
- 在线蒸馏(Online Distillation)
- 边训练教师边蒸馏给学生。多用于自监督预训练时,让学生同步学到最新的教师知识。
- 提示工程 (Prompt Engineering) + 蒸馏
- 大模型在 Prompt 场景下如何回答,可以让学生通过蒸馏保留教师的“对话风格”或“知识点”。
- 区分度损失
- 除了让学生模仿教师的输出,也可以考虑保留教师在不同输入间的差异(以免学生学得太死板)。
八、总结
- 知识蒸馏 (Distillation) 是利用训练好的大模型 (Teacher) 来指导小模型 (Student) 学习的过程,能在不显著牺牲性能的前提下大幅减少模型规模和推理开销。
- 常见方法包括Logits 蒸馏、特征蒸馏、注意力分布蒸馏等。对Transformer 等深层大模型,蒸馏经常与层数精简、参数精简相结合。
- 实际上,典型的DistilBERT 就是成功案例:6层学生在保留 ~97% BERT-base 精度的同时,模型变小 40% 且推理速度提升 60%。
- 在超大模型(十亿级参数以上)的场景下,蒸馏可以极大缓解高额的存储和推理成本,但需在数据量、训练时间、教师-学生结构对齐等方面投入更多设计。
- 蒸馏往往与剪枝、量化、模型并行等技术一起配合,形成系统化的模型压缩与加速方案。
参考与延伸
- 原始知识蒸馏论文: “Distilling the Knowledge in a Neural Network” (Hinton et al., 2015)
- DistilBERT 论文: “DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter” (Sanh et al., 2019)
- TinyBERT: “TinyBERT: Distilling BERT for Natural Language Understanding”
- Hugging Face Transformers: 提供了一些蒸馏范例与脚本
- 脚本式蒸馏:各大模型库(Megatron-LM、DeepSpeed)也提供了对超大模型的蒸馏支持,可结合分布式训练策略
- 多视角蒸馏: 既蒸馏 logits,也蒸馏隐藏层特征、注意力、梯度,甚至任务级别的表示
通过上述介绍与示例,相信你对大模型的知识蒸馏有了更系统深入的理解,并能在实际开发中利用蒸馏技术来压缩模型、加速推理,并在小模型与大模型性能之间找到适当的平衡。
【哈佛博后带小白玩转机器学习】 哔哩哔哩_bilibili
总课时超400+,时长75+小时