从蒸馏到压缩:两种方式训练GPT-2小模型 Student

本文展示如何使用两种方式训练一个 GPT-2 Student 小模型,一种是加载原始 GPT-2 模型直接蒸馏(v1),另一种是使用 GPT2Config 构建结构更小的模型进行压缩蒸馏(v2)。同时提供完整训练流程、代码与对比说明。


📌 什么是模型蒸馏(Distillation)?

模型蒸馏是一种将大模型(Teacher)知识压缩到小模型(Student)的方法:

  • Teacher 输出 logits,作为软标签指导 student 学习
  • 训练目标不再是标签,而是“模仿”大模型的行为
  • 常用于模型压缩、推理加速、边缘部署

🔁 版本对比:v1 vs v2

方面distill_training.py(v1)distill_training_v2.py(v2)
Student 构建方式直接加载 GPT2使用 GPT2Config 构建 6层小模型
Transformer 层数126
模型大小与 teacher 一样小一半
场景适配蒸馏演示模型压缩 + 蒸馏

🧱 步骤通用结构(两个版本都包含)

1️⃣ 加载模型和分词器

teacher_model = GPT2LMHeadModel.from_pretrained("path_to_teacher").eval()
tokenizer = GPT2Tokenizer.from_pretrained("path_to_teacher")

# v1:加载标准 GPT2 作为 student
student_model = GPT2LMHeadModel.from_pretrained("gpt2").train()

# v2:构建 6层 transformer 小模型
config = GPT2Config(n_layer=6, n_embd=768, n_head=12, vocab_size=50257)
student_model = GPT2LMHeadModel(config).train()
  • vocab_size 必须一致(否则 tokenizer 报错)
  • n_layer 越小,模型越轻量,蒸馏越有意义

2️⃣ 构造数据集与 DataLoader

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=64):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        inputs = self.tokenizer(self.texts[idx], return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
        return {k: v.squeeze(0) for k, v in inputs.items()}

训练数据(示例):

train_texts = [
    "Hello world!", "The sky is blue.",
    "AI is changing the world.", ...
]

3️⃣ 蒸馏训练核心逻辑:KL Loss

loss_fn = torch.nn.KLDivLoss(reduction="batchmean")

loss = loss_fn(
    student_logits.log_softmax(dim=-1),
    teacher_logits.softmax(dim=-1)
)
  • 使用 softmax + log_softmax 组合符合 KL 散度要求
  • 不用交叉熵,因为没有 ground truth label,只有 teacher 输出作为 soft target

4️⃣ 梯度更新流程

optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)

for batch in dataloader:
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

    student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

    loss = loss_fn(
        student_logits.log_softmax(dim=-1),
        teacher_logits.softmax(dim=-1)
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

💾 模型保存

save_path = "./gpt2_student_v2"  # 或 gpt2_student(v1)
student_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

训练完成后即可用于推理,支持使用 Hugging Face 接口加载。


🎯 总结

  • 蒸馏是一种让 student 模型模仿 teacher 的训练方式
  • v1 是结构不变的蒸馏(偏向学习机制)
  • v2 是结构压缩 + 蒸馏(更适合部署)
  • 使用 KL Loss 衡量两个模型输出概率分布的差异

🧭 本文为 GPT-2 蒸馏压缩项目第一篇,共3篇


📌 YoanAILab 技术导航页

💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页

📚 包含内容:

  • 🧠 GPT-2 项目源码(GitHub)
  • ✍️ CSDN 技术专栏合集
  • 💼 知乎转型日志
  • 📖 公众号 YoanAILab 全文合集
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YoanAILab

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值