知识蒸馏(Knowledge Distillation,KD)是一种将 大模型(Teacher) 的知识迁移到 小模型(Student) 的方法,能在 保持较高精度的同时减少计算资源消耗。企业可以用 Hugging Face 的工具 高效地训练小模型,以加速推理并降低成本。
1. 知识蒸馏的核心思想
知识蒸馏的目标是让 小模型(Student) 学习 大模型(Teacher) 的行为,通常采用以下策略:
-
软目标(Soft Targets):使用大模型的 softmax 预测值(概率分布)作为小模型的学习目标,而不是硬标签(one-hot)。
-
隐藏层匹配(Intermediate Feature Matching):让小模型的中间层特征与大模型尽可能相似。
-
辅助损失(Distillation Loss):引入
KL 散度(Kullback-Leibler Divergence)
让小模型更好地模仿大模型。
2. 适用场景
-
模型压缩:如
bert-large
→bert-base
或bert-mini
-
降低推理延迟:如
GPT-3
→DistilGPT-2
-
边缘设备部署:在 移动设备/嵌入式设备 上运行更轻量的模型
-
自定义数据蒸馏:使用企业数据让小模型更贴近业务需求
3. 知识蒸馏的方法
方法 1:基于 Soft Targets(Logits 蒸馏)
让 小模型学习大模型的输出概率分布,核心损失函数:
其中:
-
y_true
:真实标签 -
y_teacher
:大模型输出的概率分布 -
y_student
:小模型输出的概率分布 -
α
:权重因子(控制知识蒸馏损失与标准交叉熵损失的平衡)
方法 2:中间层对齐(Feature Matching)
匹配 Teacher 和 Student 的中间层特征,计算 L2 Loss
让小模型的特征表征接近大模型。
方法 3:组合多种蒸馏策略
综合 Soft Targets + Feature Matching + Attention Matching,让小模型全方位学习大模型的知识。
4. Hugging Face 实现知识蒸馏
(1)环境安装
pip install transformers datasets torch accelerate
(2)加载 Teacher 和 Student 模型
假设我们使用 bert-base-uncased
作为 Teacher,distilbert-base-uncased
作为 Student:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
teacher_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
student_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
(3)数据准备
加载 Hugging Face imdb
电影评论数据集:
from datasets import load_dataset
dataset = load_dataset("imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]
预处理文本:
def preprocess_function(examples):
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
train_dataset = train_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.map(preprocess_function, batched=True)
(4)定义知识蒸馏损失
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
"""
计算知识蒸馏损失:
- student_logits: 学生模型的预测值
- teacher_logits: 老师模型的预测值
- labels: 真实标签
- temperature: 蒸馏温度(越高,Soft Targets 分布越平滑)
- alpha: 知识蒸馏损失的权重
"""
# KL 散度损失(使用 Soft Targets)
kl_loss = F.kl_div(
F.log_softmax(student_logits / temperature, dim=-1),
F.softmax(teacher_logits / temperature, dim=-1),
reduction="batchmean"
) * (temperature ** 2)
# 交叉熵损失(使用真实标签)
ce_loss = F.cross_entropy(student_logits, labels)
# 综合损失
return alpha * kl_loss + (1 - alpha) * ce_loss
(5)训练蒸馏模型
from torch.utils.data import DataLoader
from transformers import AdamW
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model.to(device)
student_model.to(device)
optimizer = AdamW(student_model.parameters(), lr=5e-5)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
num_epochs = 3
for epoch in range(num_epochs):
student_model.train()
teacher_model.eval()
total_loss = 0
for batch in train_dataloader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["label"].to(device)
with torch.no_grad():
teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits
student_logits = student_model(input_ids, attention_mask=attention_mask).logits
loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_dataloader)}")
(6)评估 Student 模型
from sklearn.metrics import accuracy_score
def evaluate(model, dataset):
model.eval()
predictions, labels = [], []
for batch in DataLoader(dataset, batch_size=8):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
label = batch["label"]
with torch.no_grad():
logits = model(input_ids, attention_mask=attention_mask).logits
preds = torch.argmax(logits, dim=-1).cpu().numpy()
predictions.extend(preds)
labels.extend(label.numpy())
return accuracy_score(labels, predictions)
student_acc = evaluate(student_model, test_dataset)
teacher_acc = evaluate(teacher_model, test_dataset)
print(f"Teacher Accuracy: {teacher_acc}")
print(f"Student Accuracy: {student_acc}")
5. 进阶优化
(1)蒸馏不同的 Student 模型
-
DistilBERT
(6 层) → 适合 NLP 任务 -
TinyBERT
(4 层) → 更小更快 -
MobileBERT
(量化压缩) → 适合移动设备
(2)使用 Hugging Face 的 DistilBERT 预训练
如果不想手动实现蒸馏,可以直接使用:
from transformers import DistilBertForSequenceClassification
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
(3)混合蒸馏
可以结合 隐藏层蒸馏(Feature Matching)+ Soft Targets,提高蒸馏效果。
6. 总结
-
选择 大模型(Teacher) 和 小模型(Student)。
-
让小模型学习 大模型的 Soft Targets(概率分布),并使用 KL 散度损失 进行优化。
-
训练后评估小模型的性能,并与大模型对比。
-
采用
DistilBERT
、TinyBERT
或MobileBERT
进一步优化。
这样,企业可以在 减少计算资源的同时保持模型精度,实现高效推理 🚀!