以下内容将详细介绍大模型的剪枝 (Pruning) 技术:从基本概念和常用方法,到大模型(如 Transformer 类模型)特有的应用挑战与实践,再通过一个基于 PyTorch 的可运行示例演示如何对模型进行简单的剪枝操作。
一、什么是剪枝 (Pruning)
剪枝 (Pruning) 指的是在不显著牺牲模型精度的前提下,删除或置零一部分不重要的参数(或结构),从而减少模型大小、降低计算量、加快推理速度的一种模型压缩技术。对于在推理端部署的大规模语言模型(如 GPT、BERT、LLAMA 等),剪枝可以起到减小显存占用、提高吞吐的作用。
1. 剪枝的分类
-
非结构化剪枝(Unstructured Pruning)
- 针对单个权重参数逐个裁剪(将某些权重置零),不关注整个通道/卷积核/注意力头等结构。
- 常见方法:基于权重大小 (Magnitude-based)、基于梯度、基于重要度评分等。
- 优点:灵活,能在参数矩阵中“细粒度”去除不重要的权重,从而能得到较高的剪枝率。
- 缺点:硬件并行库对“稀疏”不友好,实际速度提升往往不明显 unless 硬件和框架对稀疏运算有专门支持。
-
结构化剪枝(Structured Pruning)
- 以通道 (channels)、卷积核 (filters)、注意力头 (attention heads) 或层 (layers) 等为剪枝单位。
- 对 Transformer 类模型,可以剪枝注意力头 (Head Pruning) 或移除整层 (Layer Pruning / Distillation)。
- 优点:容易与现有高效的并行运算结合,可带来实际速度/内存的明显降低;实现也更直观。
- 缺点:粒度较粗,对模型精度影响往往更大,难以达到极高的剪枝率。
-
混合剪枝
- 在不同模块使用不同剪枝策略;或在同一权重矩阵先做结构化,再在剩余部分做非结构化。
2. 剪枝的主要流程
- 确定重要度 (Importance) 评分:如权重绝对值大小、梯度、利用率、敏感性分析等。
- 指定剪枝比例:决定要剪枝多少参数(如 20% 或 50%),或直接指定目标稀疏度。
- 应用剪枝:将较不重要的参数置零 (非结构化) 或去除 (结构化)。
- 可选:再训练或微调 (Fine-tune)
- 修复剪枝造成的性能损失,使模型重新适应被剪枝后的结构。
对于大模型,如 GPT/BERT,会额外面临模型层数多、结构复杂、注意力头/FFN 大等问题,需要选取合适的剪枝策略来平衡“加速 vs. 精度损失”。
二、大模型剪枝的常用思路
-
注意力头剪枝 (Head Pruning)
- Transformer 中的多头注意力 (Multi-Head Attention) 有多个独立的头 (head)。很多研究发现,有些头冗余或功能重复,对最终推断影响小。
- 剪掉不重要的头之后,可减少相应线性投影的参数量和计算量。
- 工具:例如 Hugging Face Transformers 的 head pruning 方法 等。
-
层剪枝 (Layer Pruning)
- 直接减少 Transformer 堆叠层数。例如著名的 DistilBERT 就是在训练过程中知识蒸馏 (Distillation),使模型层数减半,却保持较高性能。
- 缺点:削减层会对复杂性较高的任务影响更大。
-
通道/稀疏剪枝 (Structure in Linear layers)
- 在 MLP / FFN / Linear 层中去除部分通道或隐藏单元,也称为 “filter pruning / column pruning”。
- 有时结合逐通道 (channel-wise) 量化或逐通道正则化来识别不重要通道。
-
非结构化剪枝 (Magnitude-based)
- 最常见也是最容易做的实验:根据权重绝对值大小阈值,直接把最小的 x% 权重置为 0。
- 精度保留往往不错,但速度提升有限(除非有稀疏加速库支持)。
三、实际可运行示例:PyTorch 上的剪枝演示
下面以 BERT-base 或 GPT-2 的小模型为例,做一个非结构化剪枝的简单演示。我们会使用 PyTorch 内置的 torch.nn.utils.prune
工具来实现。
注:此示例仅演示剪枝操作与概念,实际在大模型上常需要更加细致的策略、更多数据微调和更好的硬件支持,才能获得显著的内存或速度收益。
3.1 准备环境
pip install torch transformers
3.2 代码示例
以下示例以 GPT2(gpt2
)为例,做一个简单的一键非结构化剪枝(Magnitude-based)示例。
import os
import torch
import torch.nn.utils.prune as prune
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# 1. 加载 GPT2 小模型(约 117M 参数)
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
# 2. 定义一个简单函数,用来可视化模型稀疏度
def see_weight_sparsity(model):
sum_list = 0
zero_sum = 0
for name, param in model.named_parameters():
if param.requires_grad and "weight" in name:
# 排除 bias、ln、embedding 只看 weight
sum_list += param.nelement()
zero_sum += torch.sum(param==0).item()
print(f"Sparsity: {100*zero_sum/sum_list:.2f}% zeros in weights")
# 3. 剪枝前,查看稀疏度(应该是0%)
see_weight_sparsity(model)
# 4. 对线性层权重进行全局非结构化剪枝
# 比如剪去10%最小绝对值的权重
parameters_to_prune = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
parameters_to_prune.append((module, 'weight'))
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.1, # 剪掉10%权重
)
# 5. 再次查看稀疏度
see_weight_sparsity(model)
# 6. 简单测试:推理速度、输出结果差异
test_text = "Hello, I'm a language model"
inputs = tokenizer.encode(test_text, return_tensors="pt")
with torch.no_grad():
original_output = model(inputs)
print("Logits shape:", original_output.logits.shape)
# 如果想后续永久保留零权重,可调用 prune.remove()
# 否则默认是用mask的方式保留原来的param,但可以这样做:
for (module, param_name) in parameters_to_prune:
prune.remove(module, param_name)
运行结果解释
- 初始稀疏度 通常是 0.0%,因为 GPT2 模型权重默认是密集存储。
- 剪枝后,打印出新的稀疏度,例如 ~10.0%。
- 你可以继续加大剪枝比例(如 50%)或多次迭代地裁剪,以观察对模型的影响。
- 实际推理速度 可能不会有显著提升,因为默认的稠密 BLAS 操作库无法利用稀疏 pattern,需要在具备稀疏支持的库或硬件上才能见到加速。
- 若想补偿性能损失,可在剪枝后进行少量 微调 (Fine-tuning),让模型适应被置零的权重。
四、进一步的思考与进阶
-
结合蒸馏 (Distillation):在对大模型进行剪枝后,往往要做一次教师-学生蒸馏,让被剪枝(或结构缩减)的学生模型在大量数据上模仿教师模型的输出,以尽量恢复精度。
-
基于注意力头剪枝
- 对 Transformer 而言,剪枝注意力头经常被视为一种更有针对性的结构化剪枝。
- Hugging Face Transformers 中有一些现成的函数可将某些注意力头设为无效或移除。
-
Movement Pruning / Lottery Ticket
- 更高阶的剪枝策略如 Movement Pruning 在微调过程中动态测量梯度方向,能得到更好的稀疏性与精度。
- Lottery Ticket Hypothesis 提出:在大模型中存在“幸运子网络 (winning subnet)”,能在初始训练阶段就被识别出来并独立训练;对超大模型的有效性尚在研究。
-
部署效率:
- 在 GPU/TPU 上获得真正的推理加速往往依赖结构化剪枝(如整通道或 attention head 剪枝),因为非结构化稀疏会破坏硬件优化。
- 部分框架(如 NVIDIA TensorRT、Intel MKL-Sparse)或自定义稀疏 kernel,才能从非结构化剪枝中获益。
五、总结
- 剪枝 (Pruning) 在大模型领域具有重要意义,能减少模型大小、降低推理成本。但需要在剪枝率与精度损失之间谨慎权衡。
- 非结构化剪枝方法简单,稀疏率可观,但加速收益有限;结构化剪枝更适合实际部署,但需要更加谨慎地设计策略。
- 在实际工业场景中,经常会结合知识蒸馏、量化、混合并行等多种手段,让大模型在推理端保持较好的运行效率。
- 以上示例只是简单展示了 PyTorch 内置的“基于权重大小的全局非结构化剪枝”。更多研究或项目中,剪枝往往与微调 (Finetune) 或 蒸馏 (Distill) 一起进行,以获得最优效果。
参考与延伸
- PyTorch 剪枝文档: PyTorch Pruning Tutorial
- Hugging Face Transformers: Head Pruning Example
- DistilBERT: 论文 “DistilBERT, a distilled version of BERT”
- Movement Pruning: 论文 "Movement Pruning: Adaptive Sparsity by Fine-Tuning"
- Lottery Ticket Hypothesis: 论文 "The Lottery Ticket Hypothesis"
通过以上对大模型剪枝的原理、场景和可运行示例的介绍,希望你能快速上手并在实践中灵活运用、探索不同剪枝策略,从而在保证模型精度的同时有效地减少计算开销与模型大小。
【哈佛博后带小白玩转机器学习】 哔哩哔哩_bilibili
总课时超400+,时长75+小时