【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝

这里提出了一种非深度图算法DepGraph,实现了架构通用的结构化剪枝,适用于CNNs, Transformers, RNNs, GNNsLLM大语言模型等网络。
该算法能够自动地分析复杂的结构耦合,从而正确地移除参数实现网络加速。基于DepGraph算法,我们开发了PyTorch结构化剪枝框架 Torch-Pruning。不同于依赖Masking实现的“模拟剪枝”,该框架能够实际地移除参数和通道,降低模型推理成本。在DepGraph的帮助下,研究者和工程师无需再与复杂的网络结构斗智斗勇,可以轻松完成复杂模型的一键剪枝。
DepGraph算法 论文标题:DepGraph: Towards Any Structural Pruning
DepGraph算法 论文地址:https://arxiv.org/abs/2301.12900
Torch-Pruning工具 github仓库:https://github.com/VainF/Torch-Pruning

一、 下载Torch-Pruning工具

在这里插入图片描述

二、 准备DeepLabV3+代码

这里我使用的是B站UP主bubbliiiing复现的DeepLabV3+代码
github仓库地址:https://github.com/bubbliiiing/deeplabv3-plus-pytorch

在这里插入图片描述

三、 baseline模型训练

在剪枝之前,我们需要正常准备数据,训练出最佳的模型。在剪枝之前,模型的大小为12.9MB,测试效果如下
在这里插入图片描述
在这里插入图片描述

四、 开始剪枝

步骤一、将下载的Torch-Pruning工具库中的torch_pruning文件夹复制到DeepLabV3+代码根目录下
在这里插入图片描述

步骤二、运行下面代码,实现结构化剪枝。

# DeeplabV3 prune code
# 2024/5/1

import torch
import torch_pruning as tp

device = 'cuda'

# Step 0. 加载模型和权重
# 加载模型和预训练权重
model = torch.load('logs/before_prune.pth',map_location=device)
model.eval()
inputs = torch.randn(1, 3, 640, 640).to(device)

# 统计剪枝前参数量
macs, nparams = tp.utils.count_ops_and_params(model, inputs)
print("剪枝前: macs=%d, nparams=%d"%(macs, nparams))

# Step 1. 重要性评判器
imp = tp.importance.MagnitudeImportance(p=2) # L2 norm pruning

# Step 2. 初始化剪枝器
# Step 2.1. head不参与剪枝
# 我这样用的是语义分割模型,cls_conv是里面的结构命名,具体的参数名需要根据自己实际模型中的网络命名进行修改
ignored_layers = []
for name, m in model.named_modules():
    if 'cls_conv' in name:
        ignored_layers.append(m)
# Step 2.2. 初始化剪枝器
iterative_steps = 1 # progressive pruning
prune_rate = 0.5 # 剪枝率
pruner = tp.pruner.MagnitudePruner(model=model,
                                   example_inputs=inputs,
                                   importance=imp,
                                   iterative_steps=iterative_steps,
                                   pruning_ratio=prune_rate,
                                   ignored_layers=ignored_layers,)


# Step 4. 进行剪枝
pruner.step()
# 统计剪枝后参数量
macs, nparams_pruned = tp.utils.count_ops_and_params(model, inputs)
print("剪枝后: macs=%d, nparams=%d"%(macs, nparams_pruned))
params_ratio = nparams_pruned / nparams
print("参数量比: ratio = %f" %(params_ratio))


# Step 6. save
torch.save(model, 'after_pruned.pth') # without .state_dict

在这里插入图片描述

注意,我们需要使用torch.load()torch.save()将模型结构和权重完整的保存下来,不用使用只保留权重(state_dict)的方式。可以看到剪枝后,大小为3.54MB,体积变为了baseline模型的1/4,在不进行精度恢复训练之前,测试一下模型效果,发现完全无效,这是因为模型结构发生了破坏(剪枝),所以下一步还需要精度恢复训练。
在这里插入图片描述
在这里插入图片描述

五、 精度恢复训练

剪枝完后,我们需要使用torch.load的方式加载3.54MB的剪枝模型,然后按照正常的训练流程,对剪枝模型进行精度恢复训练。

model = torch.load(model_path,map_location=device)

训练后,我们再一次测试3.54MB剪枝模型的效果,发现精度已经恢复,且几乎无损,模型大小却已经压缩为原来的1/4
在这里插入图片描述

  • 16
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
近年来,深度学习技术在计算机视觉领域飞速发展,图像语义分割也得到了广泛关注。其中,Google的deeplabv3技术是图像语义分割领域的一项重要成果,凭借其准确性和鲁棒性,在很多领域得到了广泛应用。然而,deeplabv3模型在分割细节信息方面还留有些许不足,为此,一些学者提出了一些改进策略,如下: 1. 空洞卷积增加感受野 deeplabv3使用空洞卷积提取特征,可以有效地扩大感受野,但仍存在一定不足。学者们提出了使用不同大小的空洞卷积以及空洞卷积的多层级级联来增大感受野。这种策略并不能极大地降低分割误差,但是可以更好地提取细节信息。 2. 增加多尺度信息 多尺度信息可以有效地提高模型的性能,在deeplabv3中,学者们提出了引入多尺度信息的模块。该模块括了不同感受野大小的分支网络,这些分支网络并行处理原始图像,提取不同尺度的特征进行融合。这种方法相对效果优异,但是计算量和模型复杂度较大。 3. 全连接条件随机场 deeplabv3在预测像素的分类时并没有考虑像素之间的空间关系,为此,学者们提出使用全连接条件随机场来改进deeplabv3。全连接条件随机场可以结合图像分割结果和空间变化的统计学信息,通过反馈获得更精细的分割结果。这种方法需要更多的计算资源和更长的训练时间,但是效果相对较好。 4. 权重剪枝与网络压缩 deeplabv3在应用中计算量较大,尤其是对于移动设备等较弱算力的设备而言很难实现。为此,学者们提出了使用网络压缩和权重剪枝来减小deeplabv3的计算量和模型复杂度。这种方法能够有效优化deeplabv3的性能和效率。 综合而言,deeplabv3是图像语义分割领域中的一项重要成果,同时也存在一些不足,学者们提出的改进算法不仅可以优化其性能,还可以提高其适用性和实际应用的效果。随着深度学习技术的进一步发展,我们相信deeplabv3的改进与应用将会得到更加广泛而深入的推广和应用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

m0_51579041

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值