Deep Compression:Pruning (剪枝模型压缩)

caffe 实现剪枝模型压缩

 根据“没有免费的午餐定律”,想要好机器学习效果,就要设计很深的网络结构,导致model很大,这篇博客讲述是如何在不降低模型识别率的情况下压缩模型。

其实 network pruning 技术已经被广泛应用到CNN模型的压缩中了。 早期的一些工作中,LeCun 用它来减少网络复杂度,从而达到避免 over-fitting 的效果; 近期,其实也就是作者的第一篇网络压缩论文中,通过剪枝达到了 state-of-the-art 的结果,而且没有减少模型的准确率;
这里写图片描述

从上图的左边的pruning阶段可以看出,其过程是:

  1. 正常的训练一个网络;
  2. 把一些权值很小的连接进行剪枝:通过一个阈值来剪枝;
  3. retrain 这个剪完枝的稀疏连接的网络;

如何在caffe 实现剪枝压缩模型

一、创建新的layer

在src/caffe/layers 文件夹中添加 new_layer.cpp
在include/caffe/layers 文件夹中添加 new_layer.hpp 

不需要重新写new_layer.cpp 和 new_layer.hpp,可以复制最接近的、已有的网络层,然后微改代码,达到设计新layer的目的。比如我拷贝conv_layer.cpp和conv_layer.hpp文件,重命名为compress_conv_layer.cpp、compress_conv_layer.hpp。然后在代码里微改以下代码。

1、在compress_conv_layer.cpp添加ComputeBlobMask函数
添加ComputeBlobMask函数

CmpConvolutionLayer类中添加 ComputeBlobMask() 函数,函数的作用是对卷积层的模型参数进行排序,然后根据sparse_ratio进行裁剪,sparse_ratio参数由 prototxt 文件输入。

2、在compress_conv_layer.cpp修改 Forward_cpu 前馈代码

这里写图片描述

3、在compress_conv_layer.cpp修改Backward_cpu 反馈代码

这里写图片描述

***4、在compress_conv_layer.hpp添加以下函数声明

这里写图片描述

二、建立layer外部输入参数响应

1、在src/caffe/proto/caffe.proto文件中修改以下部分:

本示例中只有一个float sparse_ratio参数,而且class CmpConvolutionLayer : public BaseConvolutionLayer, CmpConvolutionLayer(类)是BaseConvolutionLayer(基类)的派生类。所以修改以下部分代码:

这里写图片描述
这里写图片描述

2、在src/layer_factory.cpp 中添加以下代码:

这里写图片描述

3、在src/layers/base_conv_layer.cpp 中的LayerSetUp函数中添加以下代码(我的理解是从prototxt中读取sparse_ratio参数以及初始化mask

这里写图片描述

4、在src/net.cpp 中的CopyTrainedLayersFrom函数中添加以下代码(作用是计算mask掩码,用来剪枝)
这里写图片描述

三、修改prototxt文件

这里写图片描述

  • 修改 ConvolutionCmpConvolution
  • 添加 sparse_ratio: 0.33 (设置稀疏率)

四、finetune 重新训练,得到剪枝后的model

D:/caffe/caffe-windows-ms/Build/x64/Release/caffe.exe train --solver=cifar10_quick_solver_stage_1.prototxt --weights=models/hand_stage_iter_12000.caffemodel

pause

这里写图片描述
这里写图片描述

经过测试识别率下降很少,大概有%1 的精度损失,

hand_stage_0_iter_12000识别率99%
hand_stage_3_iter_1000识别率98%。

五、识别结果展示

这里写图片描述

评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值