yolov11剪枝

思路:yolov11中的C3k2与yolov8的c2f的不同,所以与之前yolov8剪枝有稍许不同;

后续:会将剪枝流程写全,以及增加蒸馏、注意力、改loss;

注意:

1.在代码105行修改pruning.get_threshold(yolo.model, 0.65),可以获得不同的剪枝率;

2.改代码放在训练代码同一页面下即可;

3.在最后修改文件夹地址来获得剪枝后的模型;

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os


# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## Normal Pruning
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()

        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## Regular Pruning
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C3k2 as a top conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1

        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.65)  # 这里的0.8为剪枝率。

    ### 1. 剪枝C3k2 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]
    detect: Detect = seq[-1]
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    yolo.val(data='data.yaml', batch=2, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)



if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)

YOLOv11剪枝蒸馏是一种用于优化YOLO(You Only Look Once)目标检测模型的技术。YOLO是一种实时目标检测算法,广泛应用于计算机视觉领域。剪枝和蒸馏是两种常见的模型压缩和优化技术,用于减少模型的计算复杂度和参数数量,从而提高模型的推理速度和效率。 ### YOLOv11剪枝 剪枝是一种通过移除模型中不重要的权重或神经元来减小模型大小的技术。对于YOLOv11剪枝可以通过以下步骤实现: 1. **训练原始模型**:首先在训练数据集上训练一个完整的YOLOv11模型。 2. **权重评估**:评估每个权重的重要性,通常通过计算梯度的绝对值或其他重要性指标。 3. **剪枝**:根据重要性指标移除不重要的权重或神经元。 4. **微调**:对剪枝后的模型进行微调,以恢复部分由于剪枝而丢失的精度。 ### YOLOv11蒸馏 蒸馏是一种将大型复杂模型(教师模型)的知识转移到小型简单模型(学生模型)中的技术。对于YOLOv11,蒸馏可以通过以下步骤实现: 1. **训练教师模型**:首先训练一个高性能的YOLOv11教师模型。 2. **训练学生模型**:使用教师模型的输出来训练一个较小的YOLOv11学生模型。学生模型不仅学习真实标签,还学习教师模型的输出。 3. **联合训练**:在训练过程中,学生模型和教师模型联合优化,以最小化学生模型与教师模型输出之间的差异。 ### 剪枝和蒸馏的结合 将剪枝和蒸馏结合使用,可以进一步优化YOLOv11模型。具体步骤如下: 1. **剪枝**:对原始YOLOv11模型进行剪枝。 2. **蒸馏**:使用剪枝后的模型作为学生模型,使用原始模型作为教师模型进行蒸馏。 3. **微调和优化**:对剪枝和蒸馏后的模型进行微调,以确保模型在保持较高精度的同时,具有更快的推理速度。 通过剪枝和蒸馏的结合,可以显著减少YOLOv11模型的计算复杂度和参数数量,从而提高其在实际应用中的效率和性能。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值