YOLOv8模型的剪枝是一个涉及多个步骤的复杂过程,主要包括稀疏训练、剪枝和微调。以下是YOLOv8剪枝的一般流程:
1. 稀疏训练:
在稀疏训练阶段,通过在训练过程中增加L1正则化约束,筛选出重要的通道。这一步骤可以通过修改`trainer.py`文件,为BN层的权重增加L1正则化项来实现,如在`trainer.py`中添加的代码所示 。
# ========== 新增 ==========
# l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs) # 可调
# for k, m in self.model.named_modules():
# if isinstance(m, nn.BatchNorm2d):
# m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
# m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== 新增 ==========
2. 剪枝:
使用训练后的模型(推荐使用`last.pt`而非`best.pt`)作为剪枝对象。创建`prune.py`文件,编写剪枝代码。剪枝代码需要针对YOLOv8的网络结构进行编写,如果模型有修改,则需要相应地调整剪枝代码 。
3. 检查BN层的bias:
剪枝后,检查BN层的bias是否足够小,如果不满足条件可能需要重新进行稀疏训练 。
4. 设置阈值和剪枝率:
确定全局或局部的剪枝阈值,并设置保持率(例如,`factor = 0.8`表示保留80%的通道)。
5. 剪枝操作:
要注意自己需要剪枝的模型的结构,保证上下层通道数一致,这样才可以进行下一步的验证和训练。
对模型中的卷积层进行剪枝,例如,可以编写函数`prune_conv`来实现对TopConv和BottomConv的剪枝 。
剪枝完成后最好使用下面这行代码验证一下,保证后续微调能够顺利运行
yolo.val() # 剪枝模型进行验证 yolo.val(workers=0)
6. 保存剪枝后的模型:
剪枝完成后,保存模型权重,以便进行进一步的测试或微调 。
7. 微调:
微调可能会出现无法只使用pt模型,即模型训练依然调用了yaml文件的问题,这里需要调整几行代码:
第一步:
# trainer.py中大概545行
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
# ========== 新增该行代码 ==========
self.model = weights
# ========== 新增该行代码 ==========
return ckpt
第二步:
# model.py文件的大概655行
if not args.get("resume"): # manually set model only if not resuming
######################上面两行注释掉,添加下面一行#####
self.trainer.model = self.model.train()
##########################修改####################
# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# self.model = self.trainer.model
第三步:
# 修改loss.py中188行左右
def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
# b, a, c = pred_dist.shape # batch, anchors, channels
b, a, c = pred_dist.shape
device = pred_dist.device
self.proj = self.proj.to(device)
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
这样就可以正常微调训练了
剪枝后,模型的精度可能会下降,因此需要进行微调以恢复精度。微调可以在剪枝后直接进行,或者先导出为ONNX文件,然后在Netron中检查剪枝结果后再进行 。
8. 注意事项:
在进行剪枝操作时,需要区分网络结构和网络权重,因为剪枝后的权重文件结构可能与原始的yaml文件不匹配,可能需要对yaml文件进行修改以满足剪枝后的要求 。
9. 未来工作:
可以考虑不剪枝的层不进行约束,对于低于全局阈值的模块可以整个移除,以及考虑保留通道数对硬件加速的影响 。
请注意,剪枝是一个需要细致操作的过程,需要对YOLOv8的网络结构和ONNX模型的操作非常熟悉。实际操作中可能会遇到许多细节问题,需要在实践中不断调整和优化。