这里提出了一种非深度图算法
DepGraph
,实现了架构通用的结构化剪枝,适用于CNNs
,Transformers
,RNNs
,GNNs
,LLM
大语言模型等网络。
该算法能够自动地分析复杂的结构耦合,从而正确地移除参数实现网络加速。基于DepGraph
算法,我们开发了PyTorch
结构化剪枝框架Torch-Pruning
。不同于依赖Masking实现的“模拟剪枝”,该框架能够实际地移除参数和通道,降低模型推理成本。在DepGraph的帮助下,研究者和工程师无需再与复杂的网络结构斗智斗勇,可以轻松完成复杂模型的一键剪枝。
DepGraph
算法 论文标题:DepGraph: Towards Any Structural Pruning
DepGraph
算法 论文地址:https://arxiv.org/abs/2301.12900
Torch-Pruning
工具github
仓库:https://github.com/VainF/Torch-Pruning
一、 下载Torch-Pruning工具
二、 准备DeepLabV3+代码
这里我使用的是B站UP主bubbliiiing复现的DeepLabV3+
代码
github仓库地址:https://github.com/bubbliiiing/deeplabv3-plus-pytorch
三、 baseline模型训练
在剪枝之前,我们需要正常准备数据,训练出最佳的模型。在剪枝之前,模型的大小为12.9MB
,测试效果如下
四、 开始剪枝
步骤一、将下载的Torch-Pruning
工具库中的torch_pruning
文件夹复制到DeepLabV3+
代码根目录下
步骤二、运行下面代码,实现结构化剪枝。
# DeeplabV3 prune code
# 2024/5/1
import torch
import torch_pruning as tp
device = 'cuda'
# Step 0. 加载模型和权重
# 加载模型和预训练权重
model = torch.load('logs/before_prune.pth',map_location=device)
model.eval()
inputs = torch.randn(1, 3, 640, 640).to(device)
# 统计剪枝前参数量
macs, nparams = tp.utils.count_ops_and_params(model, inputs)
print("剪枝前: macs=%d, nparams=%d"%(macs, nparams))
# Step 1. 重要性评判器
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning
# Step 2. 初始化剪枝器
# Step 2.1. head不参与剪枝
# 我这样用的是语义分割模型,cls_conv是里面的结构命名,具体的参数名需要根据自己实际模型中的网络命名进行修改
ignored_layers = []
for name, m in model.named_modules():
if 'cls_conv' in name:
ignored_layers.append(m)
# Step 2.2. 初始化剪枝器
iterative_steps = 1 # progressive pruning
prune_rate = 0.5 # 剪枝率
pruner = tp.pruner.MagnitudePruner(model=model,
example_inputs=inputs,
importance=imp,
iterative_steps=iterative_steps,
pruning_ratio=prune_rate,
ignored_layers=ignored_layers,)
# Step 4. 进行剪枝
pruner.step()
# 统计剪枝后参数量
macs, nparams_pruned = tp.utils.count_ops_and_params(model, inputs)
print("剪枝后: macs=%d, nparams=%d"%(macs, nparams_pruned))
params_ratio = nparams_pruned / nparams
print("参数量比: ratio = %f" %(params_ratio))
# Step 6. save
torch.save(model, 'after_pruned.pth') # without .state_dict
注意,我们需要使用torch.load()
,torch.save()
将模型结构和权重完整的保存下来,不用使用只保留权重(state_dict
)的方式。可以看到剪枝后,大小为3.54MB
,体积变为了baseline
模型的1/4
,在不进行精度恢复训练之前,测试一下模型效果,发现完全无效,这是因为模型结构发生了破坏(剪枝),所以下一步还需要精度恢复训练。
五、 精度恢复训练
剪枝完后,我们需要使用torch.load
的方式加载3.54MB
的剪枝模型,然后按照正常的训练流程,对剪枝模型进行精度恢复训练。
model = torch.load(model_path,map_location=device)
训练后,我们再一次测试3.54MB
剪枝模型的效果,发现精度已经恢复,且几乎无损,模型大小却已经压缩为原来的1/4
。