本文展示如何使用两种方式训练一个 GPT-2 Student 小模型,一种是加载原始 GPT-2 模型直接蒸馏(
v1
),另一种是使用GPT2Config
构建结构更小的模型进行压缩蒸馏(v2
)。同时提供完整训练流程、代码与对比说明。
📌 什么是模型蒸馏(Distillation)?
模型蒸馏是一种将大模型(Teacher)知识压缩到小模型(Student)的方法:
- Teacher 输出 logits,作为软标签指导 student 学习
- 训练目标不再是标签,而是“模仿”大模型的行为
- 常用于模型压缩、推理加速、边缘部署
🔁 版本对比:v1 vs v2
方面 | distill_training.py(v1) | distill_training_v2.py(v2) |
---|---|---|
Student 构建方式 | 直接加载 GPT2 | 使用 GPT2Config 构建 6层小模型 |
Transformer 层数 | 12 | 6 |
模型大小 | 与 teacher 一样 | 小一半 |
场景适配 | 蒸馏演示 | 模型压缩 + 蒸馏 |
🧱 步骤通用结构(两个版本都包含)
1️⃣ 加载模型和分词器
teacher_model = GPT2LMHeadModel.from_pretrained("path_to_teacher").eval()
tokenizer = GPT2Tokenizer.from_pretrained("path_to_teacher")
# v1:加载标准 GPT2 作为 student
student_model = GPT2LMHeadModel.from_pretrained("gpt2").train()
# v2:构建 6层 transformer 小模型
config = GPT2Config(n_layer=6, n_embd=768, n_head=12, vocab_size=50257)
student_model = GPT2LMHeadModel(config).train()
- vocab_size 必须一致(否则 tokenizer 报错)
- n_layer 越小,模型越轻量,蒸馏越有意义
2️⃣ 构造数据集与 DataLoader
class TextDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=64):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __getitem__(self, idx):
inputs = self.tokenizer(self.texts[idx], return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
return {k: v.squeeze(0) for k, v in inputs.items()}
训练数据(示例):
train_texts = [
"Hello world!", "The sky is blue.",
"AI is changing the world.", ...
]
3️⃣ 蒸馏训练核心逻辑:KL Loss
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
loss = loss_fn(
student_logits.log_softmax(dim=-1),
teacher_logits.softmax(dim=-1)
)
- 使用 softmax + log_softmax 组合符合 KL 散度要求
- 不用交叉熵,因为没有 ground truth label,只有 teacher 输出作为 soft target
4️⃣ 梯度更新流程
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
with torch.no_grad():
teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
loss = loss_fn(
student_logits.log_softmax(dim=-1),
teacher_logits.softmax(dim=-1)
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
💾 模型保存
save_path = "./gpt2_student_v2" # 或 gpt2_student(v1)
student_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
训练完成后即可用于推理,支持使用 Hugging Face 接口加载。
🎯 总结
- 蒸馏是一种让 student 模型模仿 teacher 的训练方式
v1
是结构不变的蒸馏(偏向学习机制)v2
是结构压缩 + 蒸馏(更适合部署)- 使用 KL Loss 衡量两个模型输出概率分布的差异
🧭 本文为 GPT-2 蒸馏压缩项目第一篇,共3篇
- 🧩 第一篇:从蒸馏到压缩:两种方式训练GPT-2小模型 Student
- 🚀 第二篇:GPT-2 蒸馏模型推理实战:标准 Student vs 压缩 Student 的调用对比
- 🌐 第三篇:GPT-2 蒸馏小模型部署实战:Flask 封装推理接口与网页调用演示
📌 YoanAILab 技术导航页
💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页
📚 包含内容:
- 🧠 GPT-2 项目源码(GitHub)
- ✍️ CSDN 技术专栏合集
- 💼 知乎转型日志
- 📖 公众号 YoanAILab 全文合集