YOLOv11 Model Pruning: A Complete Hands-On Guide

Table of Contents

1. Constrained Training

2. Pruning

Note: this code only works with the official YOLO11 architecture; for other versions or custom models, prune_yolo11.py must be reworked!

3. Fine-tuning After Pruning (updated 2025-04-11)


Model pruning compresses a network by removing redundant parameters or structures, cutting computation and memory footprint while preserving as much accuracy as possible.
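As a minimal, framework-free illustration of the idea (a toy function, not part of the YOLO pipeline), magnitude pruning simply zeroes the weights with the smallest absolute values:

```python
def magnitude_prune(weights, ratio):
    """Zero out the `ratio` fraction of weights with the smallest magnitude."""
    k = int(len(weights) * ratio)
    if k == 0:
        return list(weights)
    # The k-th smallest |w| is the cutoff; everything at or below it is zeroed
    # (ties at the cutoff are all zeroed in this toy version).
    cutoff = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= cutoff else w for w in weights]

print(magnitude_prune([0.5, -0.05, 0.3, 0.01], ratio=0.5))  # [0.5, 0.0, 0.3, 0.0]
```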

References:

YOLOv8 model pruning - CSDN blog

YOLOv11 pruning - CSDN blog

Ultralytics YOLO11 - Ultralytics YOLO documentation

YOLOv8 source modifications (4): YOLOv8 pruning (simple pruning for any YOLO model) - CSDN blog

1. Constrained Training

Pruning shrinks a model by removing unimportant parameters or structures, but pruning a model directly usually hurts accuracy, so some preparation is needed first: constrained training. During constrained training we introduce a constraint that makes the model's structure easier to prune while preserving accuracy.

Pruning a normally trained model directly causes a severe accuracy drop, because:

  • Poor parameter redundancy: after ordinary training, the parameter distribution may not be sparse enough to separate important weights from unimportant ones.

  • Tight structural coupling: layers depend strongly on each other, so naive pruning breaks feature-propagation paths.

  • Uneven sensitivity: different layers tolerate pruning differently and need individual treatment.

Constrained training introduces a specific constraint so the model gradually adapts to its future pruned structure, reducing the accuracy loss after pruning.

In this article, L1 regularization is applied to the BatchNorm layer weights to encourage sparsity, which makes the subsequent pruning easier.

Procedure: in ultralytics/engine/trainer.py, inside the BaseTrainer class's _do_train method, add the following code after the backward pass and before the optimizer step, so that the L1 sub-gradient is added to the freshly computed gradients:

# Decaying L1 penalty: starts at 1e-2 and shrinks to a tenth of that by the last epoch
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        # add the L1 sub-gradient sign(w) onto the existing gradients
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))  # constant penalty on the bias
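The penalty schedule decays linearly over training, so sparsification is aggressive early and gentle late. Pulled out as a hypothetical standalone helper (same formula as the l1_lambda line above):

```python
def l1_lambda_schedule(epoch, total_epochs, base=1e-2):
    # Linear decay from `base` at epoch 0 down to 0.1 * base at the last epoch,
    # matching the l1_lambda line inserted into _do_train.
    return base * (1 - 0.9 * epoch / total_epochs)

print(l1_lambda_schedule(0, 30))            # 0.01
print(round(l1_lambda_schedule(30, 30), 6))  # 0.001
```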

Then launch training:

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
from ultralytics.models import RTDETR

if __name__ == '__main__':
    # model = RTDETR(r'ultralytics/cfg/models/rt-detr/rtdetr-l.yaml') 
    model = YOLO(r"ultralytics/cfg/models/11/yolo11m.yaml")
    model.train(data=r'own.yaml',
                cache=False,
                imgsz=640,
                epochs=30,
                single_cls=False,  # single-class detection or not
                batch=16,
                close_mosaic=10,
                workers=0,
                device=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='Constrained Training YOLO11m',
                )

The model summary printed at startup:

                   from  n    params  module                                       arguments
  0                  -1  1      1856  ultralytics.nn.modules.conv.Conv             [3, 64, 3, 2]
  1                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  2                  -1  1    111872  ultralytics.nn.modules.block.C3k2            [128, 256, 1, True, 0.25]
  3                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
  4                  -1  1    444928  ultralytics.nn.modules.block.C3k2            [256, 512, 1, True, 0.25]     
  5                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]
  6                  -1  1   1380352  ultralytics.nn.modules.block.C3k2            [512, 512, 1, True]
  7                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]
  8                  -1  1   1380352  ultralytics.nn.modules.block.C3k2            [512, 512, 1, True]
  9                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]
 10                  -1  1    990976  ultralytics.nn.modules.block.C2PSA           [512, 512, 1]
 11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 13                  -1  1   1642496  ultralytics.nn.modules.block.C3k2            [1024, 512, 1, True]
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 16                  -1  1    542720  ultralytics.nn.modules.block.C3k2            [1024, 256, 1, True]
 17                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]
 18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 19                  -1  1   1511424  ultralytics.nn.modules.block.C3k2            [768, 512, 1, True]
 20                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]
 21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 22                  -1  1   1642496  ultralytics.nn.modules.block.C3k2            [1024, 512, 1, True]
 23        [16, 19, 22]  1   1412566  ultralytics.nn.modules.head.Detect           [2, [256, 512, 512]]
YOLO11m summary: 231 layers, 20,054,550 parameters, 20,054,534 gradients, 68.2 GFLOPs

The trained weights are now saved under runs/train/Constrained Training YOLO11m, ready for the next step.

Run a quick validation:

YOLO11m summary (fused): 125 layers, 20,031,574 parameters, 0 gradients, 67.7 GFLOPs
val: Scanning /home/hairou/ctc/yolo11/Dataset/labels/val.cache... 2000 images, 65 backgrounds, 0 corrupt: 100%|██████████| 2065/2065 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 65/65 [00:15<00:00,  4.33it/s]
                   all       2065       4914      0.996      0.982      0.995      0.991
       shallow_box_rgb       2000       2027      0.994      0.989      0.995      0.992
  shallow_half_box_rgb       1823       2887      0.998      0.975      0.994       0.99
Speed: 0.2ms preprocess, 4.2ms inference, 0.0ms loss, 0.6ms postprocess per image
Results saved to runs/val/exp2

2. Pruning

The code in the references has a few problems; I made some fixes.

The core idea: L1 regularization has pushed many BN weights toward zero, so we remove the channels whose weights are close to zero. A PRUNE class provides methods to compute the threshold, prune a conv layer, and prune module pairs; a do_pruning function loads the model, runs the pruning steps, and saves the result; the main block calls do_pruning with the model path and save path.

Create a new file prune_yolo11.py in the project folder. It exposes three parameters:

  • third line from the bottom: modelpath, the path to the constrained-trained .pt model
  • second line from the bottom: savepath, where the pruned .pt model is saved
  • pruning.get_threshold(yolo.model, 0.8) inside do_pruning: the factor 0.8 sets the pruning threshold and can be adjusted; note that because get_threshold sorts in descending order, the factor is roughly the fraction of BN channels kept, so a smaller value prunes more
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os
 
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
 
class PRUNE():
    def __init__(self) -> None:
        self.threshold = None
 
    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
 
    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## Normal Pruning
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()
 
        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## if fewer than 8 kernels would remain, halve the threshold and retry
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n
 
        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
 
        ## Regular Pruning
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
 
    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C3k2 as a top conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1
 
        self.prune_conv(m1, m2)
 
 
def do_pruning(modelpath, savepath):
    pruning = PRUNE()
 
    ### 0. Load the model
    yolo = YOLO(modelpath)  # load the constrained-trained weights
    pruning.get_threshold(yolo.model, 0.8)  # 0.8 sets the threshold (roughly the fraction of BN channels kept)
 
    ### 1. Prune the Bottlenecks inside C3k2
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)
 
    ### 2. Prune the convolutions connecting different modules
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])
 
    ### 3. Prune the detection head
    # At P3: the nodes after seq[15] connect to seq[16], detect.cv2[0] (box branch), and detect.cv3[0] (class branch)
    # At P4: the nodes after seq[18] connect to seq[19], detect.cv2[1], and detect.cv3[1]
    # At P5: the nodes after seq[21] connect to detect.cv2[2] and detect.cv3[2]
    detect: Detect = seq[-1]
    # proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3)):
        pruning.prune(last_input, [colast, cv2[0], cv3[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
 
    ### 4. Re-enable gradients and save the model
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    yolo.info()
 
    yolo.val(data='own.yaml', batch=16, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)

if __name__ == "__main__":
    modelpath = "ultralytics-main/runs/train/Constrained Training YOLO11m/weights/best.pt"
    savepath = "ultralytics-main/runs/train/Constrained Training YOLO11m/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
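To make get_threshold's percentile logic concrete, here is the same computation on hypothetical plain-Python values (no torch): sort the |γ| values in descending order and take the value at index len * factor as the threshold. Because of the descending sort, factor is roughly the fraction of channels kept above the threshold, not the fraction removed (the >= comparison also keeps the boundary channel itself):

```python
def bn_threshold(gammas, factor=0.8):
    # Same logic as PRUNE.get_threshold, on plain floats:
    # descending sort, threshold at index len * factor.
    s = sorted((abs(g) for g in gammas), reverse=True)
    return s[int(len(s) * factor)]

gammas = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]  # toy BN weights
t = bn_threshold(gammas, factor=0.8)
kept = [g for g in gammas if abs(g) >= t]
print(t, len(kept))  # 0.2 9 -> the top channels survive (9 of 10 here)
```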

Running get_threshold prints the max/min of |weight| and |bias| for every BN layer in YOLO11:

model.0.bn 1.298828125 0.8505859375 0.29248046875 0.002437591552734375
model.1.bn 1.1142578125 0.9462890625 0.1566162109375 4.1961669921875e-05
model.2.cv1.bn 1.08984375 0.9267578125 0.15673828125 0.0005412101745605469
model.2.cv2.bn 1.03515625 0.9755859375 0.054290771484375 4.0531158447265625e-05
model.2.m.0.cv1.bn 1.103515625 0.89111328125 0.219970703125 0.0010194778442382812
model.2.m.0.cv2.bn 1.0498046875 0.9658203125 0.051300048828125 0.00019085407257080078
model.2.m.0.cv3.bn 1.0966796875 0.97119140625 0.06610107421875 0.00041937828063964844
model.2.m.0.m.0.cv1.bn 1.107421875 0.94140625 0.1298828125 0.0010995864868164062
model.2.m.0.m.0.cv2.bn 1.123046875 0.9296875 0.06353759765625 0.00021064281463623047
model.2.m.0.m.1.cv1.bn 1.1025390625 0.912109375 0.0673828125 0.0004315376281738281
model.2.m.0.m.1.cv2.bn 1.119140625 0.982421875 0.0810546875 0.00025916099548339844
model.3.bn 1.04296875 0.978515625 0.052581787109375 4.416704177856445e-05
model.4.cv1.bn 1.0244140625 0.96435546875 0.05096435546875 3.212690353393555e-05
model.4.cv2.bn 1.02734375 0.970703125 0.0245513916015625 3.7550926208496094e-06
model.4.m.0.cv1.bn 1.0234375 0.9150390625 0.049835205078125 6.216764450073242e-05
model.4.m.0.cv2.bn 1.0029296875 0.9765625 0.016265869140625 0.0001055002212524414
model.4.m.0.cv3.bn 1.05859375 0.99365234375 0.042144775390625 1.52587890625e-05
model.4.m.0.m.0.cv1.bn 1.0576171875 0.96826171875 0.037139892578125 0.00024330615997314453
model.4.m.0.m.0.cv2.bn 1.0498046875 0.974609375 0.0245819091796875 8.064508438110352e-05
model.4.m.0.m.1.cv1.bn 1.1005859375 0.9638671875 0.0321044921875 0.000537872314453125
model.4.m.0.m.1.cv2.bn 1.0791015625 0.99560546875 0.02862548828125 9.72747802734375e-05
model.5.bn 1.0234375 0.986328125 0.02447509765625 1.33514404296875e-05
model.6.cv1.bn 1.0224609375 0.96728515625 0.0301513671875 1.8894672393798828e-05
model.6.cv2.bn 1.0478515625 0.98681640625 0.0259857177734375 1.3053417205810547e-05
model.6.m.0.cv1.bn 1.0244140625 0.91455078125 0.045501708984375 0.0001748800277709961
model.6.m.0.cv2.bn 1.0009765625 0.9833984375 0.01180267333984375 8.344650268554688e-06
model.6.m.0.cv3.bn 1.04296875 0.98876953125 0.02777099609375 1.245737075805664e-05
model.6.m.0.m.0.cv1.bn 1.0615234375 0.98046875 0.0362548828125 7.94529914855957e-05
model.6.m.0.m.0.cv2.bn 1.0615234375 0.98291015625 0.03216552734375 0.00010192394256591797
model.6.m.0.m.1.cv1.bn 1.0556640625 0.9814453125 0.037689208984375 4.4226646423339844e-05
model.6.m.0.m.1.cv2.bn 1.0576171875 1.005859375 0.02435302734375 5.143880844116211e-05
model.7.bn 1.0087890625 0.99072265625 0.0201416015625 7.212162017822266e-06
model.8.cv1.bn 1.0087890625 0.98486328125 0.025665283203125 7.152557373046875e-07
model.8.cv2.bn 1.0126953125 0.98583984375 0.0298004150390625 3.159046173095703e-06
model.8.m.0.cv1.bn 1.01171875 0.98193359375 0.0174560546875 4.708766937255859e-06
model.8.m.0.cv2.bn 1.0 0.99267578125 0.00710296630859375 1.4185905456542969e-05
model.8.m.0.cv3.bn 1.015625 0.99169921875 0.022247314453125 1.7881393432617188e-07
model.8.m.0.m.0.cv1.bn 1.0126953125 0.9873046875 0.01384735107421875 1.049041748046875e-05
model.8.m.0.m.0.cv2.bn 1.013671875 0.9921875 0.01549530029296875 8.344650268554688e-06
model.8.m.0.m.1.cv1.bn 1.0068359375 0.99169921875 0.01239013671875 8.344650268554688e-07
model.8.m.0.m.1.cv2.bn 1.01953125 0.9951171875 0.0211639404296875 7.152557373046875e-07
model.9.cv1.bn 1.03515625 0.9794921875 0.0255126953125 1.6689300537109375e-06
model.9.cv2.bn 1.017578125 0.98876953125 0.040008544921875 1.0669231414794922e-05
model.10.cv1.bn 1.0302734375 0.98095703125 0.057891845703125 1.7881393432617188e-06
model.10.cv2.bn 1.0205078125 0.98388671875 0.0238494873046875 4.0531158447265625e-06
model.10.m.0.attn.qkv.bn 1.0087890625 0.98291015625 0.2484130859375 0.0
model.10.m.0.attn.proj.bn 1.013671875 0.9921875 0.0 0.0
model.10.m.0.attn.pe.bn 1.03515625 0.98046875 0.0 0.0
model.10.m.0.ffn.0.bn 1.0068359375 0.99267578125 0.01508331298828125 4.172325134277344e-07
model.10.m.0.ffn.1.bn 1.005859375 0.99072265625 0.0 0.0
model.13.cv1.bn 1.041015625 0.95556640625 0.032257080078125 2.5033950805664062e-06
model.13.cv2.bn 1.0498046875 0.984375 0.0338134765625 1.3828277587890625e-05
model.13.m.0.cv1.bn 1.025390625 0.9423828125 0.0546875 0.0002161264419555664
model.13.m.0.cv2.bn 0.99951171875 0.982421875 0.01377105712890625 1.6748905181884766e-05
model.13.m.0.cv3.bn 1.048828125 0.98291015625 0.0379638671875 1.3113021850585938e-05
model.13.m.0.m.0.cv1.bn 1.1123046875 0.97900390625 0.051513671875 2.282857894897461e-05
model.13.m.0.m.0.cv2.bn 1.09375 0.97900390625 0.04693603515625 0.00019729137420654297
model.13.m.0.m.1.cv1.bn 1.1162109375 0.96826171875 0.03106689453125 9.834766387939453e-06
model.13.m.0.m.1.cv2.bn 1.0986328125 0.99755859375 0.04931640625 0.00026988983154296875
model.16.cv1.bn 1.0810546875 0.9296875 0.06439208984375 4.0531158447265625e-05
model.16.cv2.bn 1.1484375 0.8828125 0.09271240234375 0.0001844167709350586
model.16.m.0.cv1.bn 1.0068359375 0.9208984375 0.05078125 7.933378219604492e-05
model.16.m.0.cv2.bn 1.0078125 0.97509765625 0.0187225341796875 0.0002949237823486328
model.16.m.0.cv3.bn 1.0654296875 0.97412109375 0.05389404296875 0.0001404285430908203
model.16.m.0.m.0.cv1.bn 1.0908203125 0.958984375 0.034576416015625 0.00022542476654052734
model.16.m.0.m.0.cv2.bn 1.10546875 0.9638671875 0.054840087890625 0.00028228759765625
model.16.m.0.m.1.cv1.bn 1.1123046875 0.95654296875 0.03814697265625 0.0008211135864257812
model.16.m.0.m.1.cv2.bn 1.0810546875 0.98779296875 0.045806884765625 1.811981201171875e-05
model.17.bn 1.0107421875 0.9912109375 0.01197052001953125 8.344650268554688e-07
model.19.cv1.bn 1.0166015625 0.98779296875 0.0159149169921875 8.940696716308594e-07
model.19.cv2.bn 1.0234375 0.98486328125 0.0160980224609375 2.1457672119140625e-06
model.19.m.0.cv1.bn 1.00390625 0.96923828125 0.0095672607421875 1.7881393432617188e-06
model.19.m.0.cv2.bn 1.0009765625 0.99365234375 0.00296783447265625 4.76837158203125e-07
model.19.m.0.cv3.bn 1.013671875 0.9892578125 0.0101470947265625 2.1457672119140625e-06
model.19.m.0.m.0.cv1.bn 1.0244140625 0.98681640625 0.00632476806640625 3.7550926208496094e-06
model.19.m.0.m.0.cv2.bn 1.0458984375 0.99072265625 0.0086669921875 7.808208465576172e-06
model.19.m.0.m.1.cv1.bn 1.0263671875 0.9892578125 0.00522613525390625 2.980232238769531e-07
model.19.m.0.m.1.cv2.bn 1.025390625 0.9921875 0.008575439453125 4.172325134277344e-07
model.20.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.cv1.bn 1.0009765625 0.99853515625 0.0 0.0
model.22.cv2.bn 0.99951171875 0.9990234375 0.0007128715515136719 0.0
model.22.m.0.cv1.bn 0.99951171875 0.99853515625 0.0 0.0
model.22.m.0.cv2.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.cv3.bn 1.0 0.99853515625 0.0 0.0
model.22.m.0.m.0.cv1.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.m.0.cv2.bn 0.99951171875 0.99853515625 0.0 0.0
model.22.m.0.m.1.cv1.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.m.1.cv2.bn 0.99951171875 0.9990234375 0.0 0.0
model.23.cv2.0.0.bn 1.158203125 0.9267578125 0.11334228515625 0.0004303455352783203
model.23.cv2.0.1.bn 1.255859375 0.9365234375 0.360107421875 0.0012311935424804688
model.23.cv2.1.0.bn 1.099609375 0.962890625 0.0748291015625 0.00022161006927490234
model.23.cv2.1.1.bn 1.2255859375 0.96875 0.34814453125 0.0012159347534179688
model.23.cv2.2.0.bn 0.9990234375 0.9990234375 0.0 0.0
model.23.cv2.2.1.bn 0.9990234375 0.9990234375 0.0 0.0
model.23.cv3.0.0.0.bn 1.1171875 0.951171875 0.0545654296875 9.894371032714844e-06
model.23.cv3.0.0.1.bn 1.005859375 0.96630859375 0.05450439453125 9.59634780883789e-06
model.23.cv3.0.1.0.bn 1.12109375 0.9296875 0.059326171875 1.9431114196777344e-05
model.23.cv3.0.1.1.bn 1.1337890625 0.91748046875 0.2464599609375 0.0006527900695800781
model.23.cv3.1.0.0.bn 1.0263671875 0.98681640625 0.0114593505859375 5.960464477539063e-08
model.23.cv3.1.0.1.bn 1.00390625 0.9775390625 0.0640869140625 1.1920928955078125e-07
model.23.cv3.1.1.0.bn 1.0478515625 0.970703125 0.0250396728515625 9.5367431640625e-07
model.23.cv3.1.1.1.bn 1.1005859375 0.9873046875 0.2313232421875 0.0007276535034179688
model.23.cv3.2.0.0.bn 1.0 0.998046875 0.0006380081176757812 0.0
model.23.cv3.2.0.1.bn 0.99951171875 0.9990234375 0.0020084381103515625 0.0
model.23.cv3.2.1.0.bn 1.0009765625 0.9951171875 0.003658294677734375 0.0
model.23.cv3.2.1.1.bn 1.0302734375 0.9921875 0.193115234375 0.0001518726348876953

Note: this code only works with the official YOLO11 architecture; for other versions or custom models, prune_yolo11.py must be reworked!

After running prune_yolo11.py, a .pt file containing the pruned model is generated:

Now validate the pruned model:

YOLO11m summary (fused): 125 layers, 17,723,891 parameters, 13,270 gradients, 56.4 GFLOPs
val: Scanning /home/hairou/ctc/yolo11/Dataset/labels/val.cache... 2000 images, 65 backgrounds, 0 corrupt: 100%|██████████| 2065/2065 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 130/130 [00:28<00:00,  4.53it/s]
                   all       2065       4914        0.5   0.000247       0.25        0.2
       shallow_box_rgb       2000       2027          1   0.000493        0.5        0.4
  shallow_half_box_rgb       1823       2887          0          0          0          0
Speed: 0.1ms preprocess, 3.6ms inference, 0.0ms loss, 0.3ms postprocess per image

Both the parameter count and the compute cost dropped, but accuracy collapsed, which is why the third step, fine-tuning, is needed.

Author's note: with this code, the saved last_prune.pt (67.85 MB) being larger than the original best.pt (38.63 MB) is normal.
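A likely explanation for the size increase: torch.save(yolo.ckpt, savepath) writes the raw checkpoint dict, whereas Ultralytics normally post-processes finished runs with strip_optimizer() (dropping the optimizer state and casting weights to FP16) before best.pt is written. A hedged sketch of the dict-level cleanup; the key names are assumptions based on the usual Ultralytics checkpoint layout, so verify them against your own file:

```python
def strip_training_state(ckpt):
    # Remove training-only entries from a checkpoint dict so the saved file
    # only carries inference weights. Key names ('optimizer', 'ema', 'updates')
    # are assumptions based on the usual Ultralytics checkpoint layout.
    slim = dict(ckpt)
    for key in ("optimizer", "ema", "updates"):
        slim.pop(key, None)
    return slim

# Usage sketch (wrap with torch.load / torch.save; optionally also call
# ckpt["model"].half() to match Ultralytics' FP16 saving):
# ckpt = torch.load("last_prune.pt", map_location="cpu")
# torch.save(strip_training_state(ckpt), "last_prune_slim.pt")
```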

3. Fine-tuning After Pruning (updated 2025-04-11)

Fine-tuning differs from ordinary training in one respect: it must load the pruned .pt model from step 2, but the YOLO framework by default rebuilds the model from the original YAML file. So if you use the following code as-is, it still ends up loading the original YOLO11m:

# The classic mistake: this still trains the original architecture
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
from ultralytics.models import RTDETR

if __name__ == '__main__':
    # model = RTDETR(r'ultralytics/cfg/models/rt-detr/rtdetr-l.yaml') 
    model = YOLO(r"runs/train/Constrained Training YOLO11m/weights/last_prune.pt")
    model.train(data=r'own.yaml',
                cache=False,
                imgsz=640,
                epochs=30,
                single_cls=False,  # single-class detection or not
                batch=16,
                close_mosaic=10,
                workers=0,
                device=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

So how do we change this default loading logic?

Step 1: remove the code added to BaseTrainer._do_train in ultralytics/engine/trainer.py during constrained training (details omitted).

Step 2: in the BaseTrainer class's setup_model method in ultralytics/engine/trainer.py, add self.model = weights:
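Since the original screenshot is not reproduced here, the following sketch shows where the line goes. The surrounding code is abridged and reconstructed from the Ultralytics source, so it may differ between versions; only the marked line is the addition. Note this assumes fine-tuning always starts from a .pt file — if you pass a YAML config, weights is None and the added line must be guarded:

```python
# ultralytics/engine/trainer.py, BaseTrainer.setup_model -- abridged sketch
def setup_model(self):
    if isinstance(self.model, torch.nn.Module):
        return  # already an nn.Module, nothing to load
    cfg, weights = self.model, None
    ckpt = None
    if str(self.model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(self.model)
        cfg = weights.yaml
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
    self.model = weights  # <-- the added line: keep the loaded (pruned) module itself
    return ckpt
```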

Step 3: in ultralytics/engine/model.py, modify the Model class's train method:

# In ultralytics/engine/model.py: make Model.train use the custom (pruned) structure

print("-----------------------------------")
print(f"\033[1;32mINFO\033[0m: custom_model is True, load custom model. ")
for name, param in self.model.named_parameters():
    if "dfl" in name:
        param.requires_grad = False  # freeze the DFL layers
    else:
        param.requires_grad = True   # unfreeze everything else
self.trainer.model.model = self.model.model  # hand the pruned modules to the trainer

Finally, launch training again. GPU memory use during training (5.52 GB) is clearly lower than during constrained training (7.79 GB):

The final trained .pt file is only 33.98 MB, a clear reduction from the original 38.63 MB:

Before pruning: 231 layers, 20,054,550 parameters, 20,054,534 gradients, 68.2 GFLOPs
After pruning: 125 layers, 17,723,891 parameters, 0 gradients, 56.4 GFLOPs (≈11.6% fewer parameters, ≈17.3% fewer FLOPs; note the layer counts compare an unfused with a fused summary)
Author: 北京地铁1号线 (CSDN)