项目结构:
1.正常训练:
train.py
获得best.pt
剪枝有三个过程:
(1)稀疏训练
稀疏训练有两种: 1)mccp 2)ccp
train_convbn.py --weights weights/best.pt
获得如adam-p-2.0-0.001-conv-0.0001-200-s/best.pt
(2) 剪枝
8x_prune.py --weights runs/person_train/adam-p-2.4-0.001-conv-0.0001-200-s/weights/last.pt --percent 0.8
获得8x_0.8.pt
(3)微调
蒸馏只需要蒸馏训练就行。
如果如只蒸馏 不剪枝, 就运行train_distillation.py
如果是剪枝后的模型,就运行distill_finetune.py
train_distillation.py --weights 学生模型 --t_weights 教师模型 --temperature 温度系数 --dist 开启特征蒸馏
--d_feature 开启中间特征蒸馏 --layers 选择蒸馏的中间层
蒸馏的超参数有:
温度系数 --temperature 4
蒸馏的损失函数dist_loss 可以选KD散度损失等。--dist_loss kl
是否进行中间特征蒸馏,选择中间哪些层。--d_feature --layers 1 2 3