YOLOv3-model-pruning
用 YOLOv3 模型在一个开源的人手检测数据集 oxford hand 上做人手检测,并在此基础上做模型剪枝。对于该数据集,对 YOLOv3 进行 channel pruning 之后,模型的参数量、模型大小减少 80% ,FLOPs 降低 70%,前向推断的速度可以达到原来的 200%,同时可以保持 mAP 基本不变(这个效果只是针对该数据集的,不一定能保证在其他数据集上也有同样的效果)。
环境
Python3.6, Pytorch 1.0及以上
YOLOv3 的实现参考了 eriklindernoren 的 PyTorch-YOLOv3 ,因此代码的依赖环境也可以参考其 repo
数据集准备
- 下载widerface数据集,得到压缩文件(提取码: ymx2)
- 将压缩文件解压到 Dataset
-
执行 widerface_label.py,生成 images、labels 文件夹和 train.txt、valid.txt 文件
剪枝算法介绍
本代码基于论文 Learning Efficient Convolutional Networks Through Network Slimming (ICCV 2017) 进行改进实现的 channel pruning算法,类似的代码实现还有这个 yolov3-network-slimming。原始论文中的算法是针对分类模型的,基于 BN 层的 gamma 系数进行剪枝的。
**注意**
1.训练自己的数据集时,widerface.data和widerfaces.names需要最后留一空行(换行)
而train.txt valid.txt最后一行必须是非空行(换行),否则出现IndexError: list index out of range
yolov3-face.cfg可以由 creat_custom_model.sh生成
2.正常训练(Baseline)
python3 train.py --model_def config/yolov3-face.cfg -lr 0.004 --data_config config/widerface.data
3.稀疏化训练
python3 train.py --model_def config/yolov3-face.cfg -sr --s 0.01 --data_config config/widerface.data
#1. 正常训练(Baseline)
python3 train.py --model_def config/yolov3-hand.cfg
# 2.以下只是剪枝算法的大概步骤,具体实现过程中还要做 s 参数的尝试或者需要进行迭代式剪枝等。
# 2.1 进行稀疏化训练
python3 train.py --model_def config/yolov3-hand.cfg -sr --s 0.01
# 2.2 基于 test_prune.py 文件进行剪枝,得到剪枝后的模型
python3 test_prune.py
# 2.3 对剪枝后的模型进行微调
python3 train.py --model_def config/prune_yolov3-hand.cfg -pre checkpoints/prune_yolov3_ckpt.pth
# 3.测试
#python3 test.py --model_def config/prune_yolov3-hand.cfg --weights_path weights/prune_yolov3_ckpt.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
python3 test.py --model_def config/prune_0.85_yolov3-hand.cfg --weights_path checkpoints/yolov3_ckpt_99_08211153.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
#==================**************************================================
#==================**************************================================
# 基于wider face数据集进行yolov3剪枝训练步骤
1.执行 widerface_label.py,生成 images、labels 文件夹和 train.txt、valid.txt 文件
**注意**
训练自己的数据集时,widerface.data和widerfaces.names需要最后留一空行(换行)
而train.txt valid.txt最后一行必须是非空行(换行),否则出现IndexError: list in