本文基于
prune_training.py
文件,展示如何使用 PyTorch 对 GPT-2 Student 模型进行 L1 不规则剪枝(Unstructured Pruning),分析剪枝策略、实现代码、效果影响及保存模型的关键细节,帮助你将训练好的模型进一步轻量化。
✂️ 为什么剪枝?
模型训练完成后,仍存在大量“权重占位但几乎不贡献预测”的参数,剪枝可以:
- ✅ 降低显存使用
- ✅ 加快推理速度
- ✅ 保持原模型结构(不影响部署)
本项目采用 PyTorch 自带的 torch.nn.utils.prune
模块完成剪枝。
📁 项目结构
.
├── prune_training.py
├── ../python3_distillation/gpt2_student_v2/ # 蒸馏后小模型
├── ./gpt2_student_v2_pruned/ # 剪枝后输出目录
1️⃣ 剪枝配置与准备
model_path = "../python3_distillation/gpt2_student_v2"
save_path = "./gpt2_student_v2_pruned"
prune_ratio = 0.3 # 剪掉 30% 权重
- 剪枝比例建议为 20~50%
- 剪枝后模型结构不变,仅参数被置 0
2️⃣ 加载模型
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(model_path)
此时模型为完整蒸馏后结构(如 12 层 transformer + full embedding)。
3️⃣ 扫描所有 Linear 层并执行 L1 剪枝
from torch.nn.utils import prune
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name="weight", amount=prune_ratio)
prune.remove(module, "weight") # 彻底置零,永久保存
说明:
操作 | 含义 |
---|---|
l1_unstructured | 按 L1 范数排序剪掉绝对值最小的参数 |
name="weight" | 仅对 Linear 层中的权重参数剪枝 |
prune.remove() | 删除掩码 mask ,将剪枝“写死”进参数矩阵 |
⚠️ 不调用 remove()
则权重仍带有 mask,保存模型后会报错。
4️⃣ 模型保存
model.save_pretrained(save_path)
输出结构为标准 Hugging Face 格式(可被推理脚本和 API 直接加载):
gpt2_student_v2_pruned/
├── config.json
├── pytorch_model.bin
├── ...
5️⃣ 整体流程回顾
# 1. 创建输出目录
os.makedirs(save_path, exist_ok=True)
# 2. 加载模型
model = GPT2LMHeadModel.from_pretrained(model_path)
# 3. 执行 L1 剪枝 + 清除掩码
for module in model.modules():
if isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.remove(module, "weight")
# 4. 保存模型
model.save_pretrained(save_path)
✅ 效果展示(剪枝前后对比)
- 模型结构不变
- 参数部分为 0(剪掉的连接)
- 推理函数与原始模型完全兼容(见前几篇)
❗ 注意事项
项目 | 是否重要 | 建议说明 |
---|---|---|
是否必须 prune.remove() | 是 | 否则保存模型后加载将报错 |
是否影响结构 | 否 | 仍是 GPT2LMHeadModel 结构 |
是否影响 tokenizer | 否 | tokenizer 与模型剪枝无关 |
剪枝比例是否越高越好 | 否 | 推荐 20~50%,过高影响性能 |
📌 总结
- 本文展示了如何对 GPT-2 小模型进行结构不变的 L1 剪枝
- 剪枝操作基于
torch.nn.utils.prune
,简单可靠 - 剪枝模型可无缝应用于原推理服务和 API 封装流程
🧭 本系列 GPT-2 模型剪枝部署项目系列四部曲
- 🚀 第1篇:GPT-2 小模型剪枝实战:L1 Unstructured 剪枝策略与实现详解
- 🌐 第2篇:GPT-2 剪枝模型推理函数封装实战:输入输出结构与结果解析
- 🧠 第3篇:GPT-2 剪枝前后性能对比实测:加速效果与输出一致性全分析
- 💼 第4篇:GPT-2 Student 模型剪枝部署实战:Flask 接口封装与服务调用指南
📌 YoanAILab 技术导航页
💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页
📚 包含内容:
- 🧠 GPT-2 项目源码(GitHub)
- ✍️ CSDN 技术专栏合集
- 💼 知乎转型日志
- 📖 公众号 YoanAILab 全文合集