YOLOX剪枝(官网的)
本文主要是根据csdn 爱吃肉的鹏,他的pytorch-yolox剪枝代码所改,非常感谢大佬的帮助,正好是我所需要的,但是我发现B站大佬这个代码训练出来的精度会比官网低10%的map 0.5,在我的Dota-1.5数据集,有一个类总是AP=0,我感觉应该是B站大佬的代码写的不完整,导致推理的精度低,所以我直接用的YOLOX官网,与大佬的版本相结合。又加了一个YOLOX的整体剪枝。
环境
https://github.com/Megvii-BaseDetection/YOLOX.git
这里把YOLOX的代码下载到本地,并根据其安装好环境,这里就不再赘述了。
安装包
pip install torch_pruning==0.2.7
本文实现的功能
1、支持单个卷积剪枝
2、支持网络层剪枝
3、支持模型的整体剪枝
4、剪枝后微调训练
5、修改了激活函数 silu->mish(因为项目用)
数据集格式:采用VOC数据集格式
网络剪枝
参考论文:Pruning Filters for Efficient ConvNets
首先进行模型的正常训练
1、修改代码
网络剪枝
参考论文:Pruning Filters for Efficient ConvNets
在剪枝之前需要通过tools/prunmodel.py save_whole_model(weights_path,num_classes)函数将模型的权重和结构都保存下来
weights_path:权重路径
num_classes:自己数据集的类别数
支持对某个卷积的剪枝 调用Conv_pruning(whole_model_weights):
pruning_idxs=strategy(v,amout=0.4) #0.4是剪枝率,根据需要自己修改,数越大剪枝的越多
对于单独一个卷积的剪枝,需要修改两个地方,这里的卷积层需要打印模型获得,不要自己盲目瞎猜:
支持网络层的剪枝:调用layer_pruning(whole_model_weights):
剪枝以后,会打印模型的参数量变化
支持模型整体剪枝!
python exps/network_slim/main.py
这里面的total_step=1表示只进行一轮剪枝
ch_sparsity=0.5表示剪枝率
我测试了yolox-s
剪枝后的微调训练
在tools/train.py中加入
然后在yolox/core/trainer.py中
这里的model_path换成自己剪枝后保存下来的pth文件就可以。
代码
https://github.com/lizexu123/YOLOX-pruning.git
点个关注把