YOLOv8实战-模型剪枝

本文详细描述了如何在YOLOv8模型中应用模型剪枝技术,包括约束训练(使用L1正则化)、剪枝操作以及剪枝后的回调训练,以提高模型效率并进行后续微调。
摘要由CSDN通过智能技术生成

        模型剪枝是用在模型的一种优化技术,旨在减少神经网络中不必要的参数,从而降低模型的复杂性和计算负载,进一步提高模型的效率。

        模型剪枝的流程:约束训练(constained training)、剪枝(prune)、回调训练(finetune)

        本篇主要记录自己YOLOv8模型剪枝的全过程,主要参考:YOLOv8剪枝全过程

目 录

一、约束训练(constrained training)

1、参数设置

2、稀疏训练

二、剪枝(prune)

三、回调训练(finetune)

1、代码修改

2、再训练


一、约束训练(constrained training)

1、参数设置

         设置./ultralytics/cfg/default.yaml中的amp=False

2、稀疏训练

        主要方式:在BN层添加L1正则化

        具体步骤:在./ultralytics/engine/trainer.py中添加以下内容:

                # Backward
                self.scaler.scale(self.loss).backward()

                # ========== added(新增) ==========
                # 1 constrained training
                l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
                for k, m in self.model.named_modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
                        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
                # ========== added(新增) ==========

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

        然后启动训练(/yolov8/train.py):

from ultralytics import YOLO

model = YOLO('yolov8n.yaml')

results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=300, save=True)

二、剪枝(prune)

        一该部分选用上一步训练得到的模型./runs/detect/train2/weight/last.pt进行剪枝处理。在/yolov8/下新建文件prune.py,具体内容如下:

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect

# Load a model
yolo = YOLO("./runs/detect/train2/weights/last.pt")
model = yolo.model

ws = []
bs = []

for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)


def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    # print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n

    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    if not isinstance(conv2, list):
        conv2 = [conv2]

    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]


def prune(m1, m2):
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2

    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]

    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1

    prune_conv(m1, m2)


for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i + 1])

detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

yolo.val()  # 剪枝模型进行验证 yolo.val(workers=0)
yolo.export(format="onnx")  # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # 剪枝后直接训练微调

torch.save(yolo.ckpt, "./runs/detect/train2/weights/prune.pt")
print("done")

其中,factor=0.8 表示的是保持率,factor越小,裁剪的就越多,一般不建议裁剪太多。

        运行prune.py,可得到剪枝后的模型prune.pt,保存在./runs/detect/train2/weight/中。同文件夹下,还有last.onnx,可以看到onnx文件的大小比剪枝前变小了,具体结构(onnx模型结构查看)也和剪枝前的onnx相比有了轻微变化。

三、回调训练(finetune)

1、代码修改

        首先,将先前在./ultralytics/engine/trainer.py中添加的L1正则化部分注释掉:

                # Backward
                self.scaler.scale(self.loss).backward()

                # # ========== added(新增) ==========
                # # 1 constrained training
                # l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
                # for k, m in self.model.named_modules():
                #     if isinstance(m, nn.BatchNorm2d):
                #         m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
                #         m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
                # # ========== added(新增) ==========

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

        然后,在该文件第543行左右添加代码 “self.model = weights” :

    def setup_model(self):
        """Load/create/download model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith(".pt"):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = weights.yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
        # ========== added(新增) ==========
        # 2 finetune 回调训练
        self.model = weights
        # ========== added(新增) ==========
        return ckpt

2、再训练

         利用已经剪枝好的模型prune.pt,我们再次启动训练(/yolov8/train.py):

from ultralytics import YOLO

model = YOLO('./runs/detect/train5/weights/prune.pt')
results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=100, save=True)

注意,这里把model改成了"prune.pt",而不是原来的"yolov8n.yaml"

        训练后新的模型保存在“./runs/detect/train3/weight/”中。后面可按需要进一步进行模型的推理和部署。

下一篇:YOLOv8实战-模型推理及部署-CSDN博客

  • 16
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
YOLOv8剪枝流程是指对YOLOv8模型进行剪枝的过程。剪枝是一种模型优化技术,旨在减少模型的参数量和计算量,从而提高模型的推理速度和效率。 YOLOv8剪枝流程主要包括以下几个步骤: 1. 模型分析:首先,需要对YOLOv8模型进行分析,了解模型的结构和参数分布情况。可以通过查看模型的网络结构和权重参数来获取这些信息。 2. 重要性评估:接下来,需要对模型中的各个参数进行重要性评估。常用的评估方法包括敏感度分析、梯度信息等。通过评估参数的重要性,可以确定哪些参数对模型性能的影响较小,可以被剪枝掉。 3. 剪枝策略选择:根据参数的重要性评估结果,选择合适的剪枝策略。常见的剪枝策略包括按比例剪枝、按阈值剪枝、通道剪枝等。不同的剪枝策略适用于不同的模型和任务。 4. 剪枝操作:根据选择的剪枝策略,对YOLOv8模型进行剪枝操作。剪枝操作可以通过将参数置零、删除参数等方式实现。剪枝后,模型的参数量和计算量会减少。 5. 微调和压缩:剪枝后的模型可能会出现性能下降的情况,因此需要进行微调和压缩操作。微调是指在剪枝后的模型上进行进一步的训练,以恢复模型的性能。压缩是指对剪枝后的模型进行进一步的压缩,以减小模型的存储空间和计算量。 6. 性能评估:最后,需要对剪枝后的YOLOv8模型进行性能评估,包括模型的推理速度、精度等指标。通过评估剪枝模型的性能,可以判断剪枝效果的好坏。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值