【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝

文章介绍了一种名为DepGraph的非深度图算法,用于结构化剪枝,适用于多种网络类型。Torch-Pruning框架实现在PyTorch中自动分析网络结构并进行剪枝,降低模型推理成本。通过案例展示了如何使用Torch-Pruning对DeepLabV3+进行剪枝,并进行精度恢复,证明了剪枝后的模型仍保持高效性能。
摘要由CSDN通过智能技术生成

这里提出了一种非深度图算法DepGraph,实现了架构通用的结构化剪枝,适用于CNNs, Transformers, RNNs, GNNsLLM大语言模型等网络。
该算法能够自动地分析复杂的结构耦合,从而正确地移除参数实现网络加速。基于DepGraph算法,我们开发了PyTorch结构化剪枝框架 Torch-Pruning。不同于依赖Masking实现的“模拟剪枝”,该框架能够实际地移除参数和通道,降低模型推理成本。在DepGraph的帮助下,研究者和工程师无需再与复杂的网络结构斗智斗勇,可以轻松完成复杂模型的一键剪枝。
DepGraph算法 论文标题:DepGraph: Towards Any Structural Pruning
DepGraph算法 论文地址:https://arxiv.org/abs/2301.12900
Torch-Pruning工具 github仓库:https://github.com/VainF/Torch-Pruning

一、 下载Torch-Pruning工具

在这里插入图片描述

二、 准备DeepLabV3+代码

这里我使用的是B站UP主bubbliiiing复现的DeepLabV3+代码
github仓库地址:https://github.com/bubbliiiing/deeplabv3-plus-pytorch

在这里插入图片描述

三、 baseline模型训练

在剪枝之前,我们需要正常准备数据,训练出最佳的模型。在剪枝之前,模型的大小为12.9MB,测试效果如下
在这里插入图片描述
在这里插入图片描述

四、 开始剪枝

步骤一、将下载的Torch-Pruning工具库中的torch_pruning文件夹复制到DeepLabV3+代码根目录下
在这里插入图片描述

步骤二、运行下面代码,实现结构化剪枝。

# DeeplabV3 prune code
# 2024/5/1

import torch
import torch_pruning as tp

device = 'cuda'

# Step 0. 加载模型和权重
# 加载模型和预训练权重
model = torch.load('logs/before_prune.pth',map_location=device)
model.eval()
inputs = torch.randn(1, 3, 640, 640).to(device)

# 统计剪枝前参数量
macs, nparams = tp.utils.count_ops_and_params(model, inputs)
print("剪枝前: macs=%d, nparams=%d"%(macs, nparams))

# Step 1. 重要性评判器
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning

# Step 2. 初始化剪枝器
# Step 2.1. head不参与剪枝
# 我这样用的是语义分割模型,cls_conv是里面的结构命名,具体的参数名需要根据自己实际模型中的网络命名进行修改
ignored_layers = []
for name, m in model.named_modules():
    if 'cls_conv' in name:
        ignored_layers.append(m)
# Step 2.2. 初始化剪枝器
iterative_steps = 1 # progressive pruning
prune_rate = 0.5 # 剪枝率
pruner = tp.pruner.MagnitudePruner(model=model,
                                   example_inputs=inputs,
                                   importance=imp,
                                   iterative_steps=iterative_steps,
                                   pruning_ratio=prune_rate,
                                   ignored_layers=ignored_layers,)


# Step 4. 进行剪枝
pruner.step()
# 统计剪枝后参数量
macs, nparams_pruned = tp.utils.count_ops_and_params(model, inputs)
print("剪枝后: macs=%d, nparams=%d"%(macs, nparams_pruned))
params_ratio = nparams_pruned / nparams
print("参数量比: ratio = %f" %(params_ratio))


# Step 6. save
torch.save(model, 'after_pruned.pth') # without .state_dict

在这里插入图片描述

注意,我们需要使用torch.load()torch.save()将模型结构和权重完整的保存下来,不用使用只保留权重(state_dict)的方式。可以看到剪枝后,大小为3.54MB,体积变为了baseline模型的1/4,在不进行精度恢复训练之前,测试一下模型效果,发现完全无效,这是因为模型结构发生了破坏(剪枝),所以下一步还需要精度恢复训练。
在这里插入图片描述
在这里插入图片描述

五、 精度恢复训练

剪枝完后,我们需要使用torch.load的方式加载3.54MB的剪枝模型,然后按照正常的训练流程,对剪枝模型进行精度恢复训练。

model = torch.load(model_path,map_location=device)

训练后,我们再一次测试3.54MB剪枝模型的效果,发现精度已经恢复,且几乎无损,模型大小却已经压缩为原来的1/4
在这里插入图片描述

评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BILLY BILLY

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值