Note: the code in this article only works with the official YOLO11 architecture. It does not apply to other versions or custom models; for those you will need to rework prune_yolo11.py accordingly!
Model pruning is a technique that compresses a neural network by removing redundant parameters or structures, aiming to reduce computation and memory footprint while preserving as much of the model's accuracy as possible.
References:
Ultralytics YOLO11 - Ultralytics YOLO Docs
YOLOv8源码修改(4)- YOLOv8剪枝(实现任意YOLO模型的简单剪枝) (CSDN blog)
1. Constrained Training
Pruning shrinks a model's size and computation by removing unimportant parameters or structures, but pruning a trained model directly usually hurts accuracy, so the model should be prepared beforehand. That is the role of constrained training: a constraint is introduced during training so that the network becomes easier to prune later while its accuracy is preserved.
Pruning an ordinarily trained model directly causes a severe performance drop, because:
- Weak parameter redundancy: after ordinary training the parameter distribution is usually not sparse enough to separate important from unimportant parameters.
- Strong structural coupling: layers depend heavily on each other, so naive pruning breaks the feature propagation paths.
- Uneven sensitivity: different layers tolerate pruning to very different degrees and need to be handled individually.
Constrained training introduces a specific constraint so that the model gradually adapts to the structure it will have after pruning, which reduces the performance loss caused by pruning.
In this article the constraint is an L1 penalty on the BatchNorm parameters, which encourages them to become sparse and makes the subsequent channel pruning easier. The penalty adds an L1 term on the BN parameters to the loss; its (sub)gradient with respect to each parameter is simply lambda * sign(parameter), which is exactly what the snippet below adds to the existing gradients.
To apply it, add the following code to the _do_train method of the BaseTrainer class in ultralytics/engine/trainer.py (the snippet uses torch and nn, so make sure both are imported there):
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)  # penalty strength decays from 1e-2 towards 1e-3 over training
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        # add the L1 subgradient (sign of the parameter) to the existing gradients
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))  # bias penalty keeps a fixed coefficient
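For the penalty to take effect, this snippet has to run after the backward pass has populated the gradients and before the optimizer step consumes them. A rough placement sketch, assuming the training-loop structure of a recent Ultralytics release (the names ni, last_opt_step, self.accumulate and self.optimizer_step come from that release and may differ in yours):

# Hypothetical excerpt of BaseTrainer._do_train (ultralytics/engine/trainer.py)
self.scaler.scale(self.loss).backward()      # existing backward pass

# >>> insert the L1 penalty snippet from above here <<<

if ni - last_opt_step >= self.accumulate:    # existing gradient-accumulation check
    self.optimizer_step()                    # existing optimizer step
    last_opt_step = ni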
Then launch the training:
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
from ultralytics.models import RTDETR

if __name__ == '__main__':
    # model = RTDETR(r'ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    model = YOLO(r"ultralytics/cfg/models/11/yolo11m.yaml")
    model.train(data=r'own.yaml',
                cache=False,
                imgsz=640,
                epochs=30,
                single_cls=False,  # whether this is single-class detection
                batch=16,
                close_mosaic=10,
                workers=0,
                device=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='Constrained Training YOLO11m',
                )
The printed model structure and parameters:
from n params module arguments
0 -1 1 1856 ultralytics.nn.modules.conv.Conv [3, 64, 3, 2]
1 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
2 -1 1 111872 ultralytics.nn.modules.block.C3k2 [128, 256, 1, True, 0.25]
3 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
4 -1 1 444928 ultralytics.nn.modules.block.C3k2 [256, 512, 1, True, 0.25]
5 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
6 -1 1 1380352 ultralytics.nn.modules.block.C3k2 [512, 512, 1, True]
7 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
8 -1 1 1380352 ultralytics.nn.modules.block.C3k2 [512, 512, 1, True]
9 -1 1 656896 ultralytics.nn.modules.block.SPPF [512, 512, 5]
10 -1 1 990976 ultralytics.nn.modules.block.C2PSA [512, 512, 1]
11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
12 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1]
13 -1 1 1642496 ultralytics.nn.modules.block.C3k2 [1024, 512, 1, True]
14 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
15 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1]
16 -1 1 542720 ultralytics.nn.modules.block.C3k2 [1024, 256, 1, True]
17 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
18 [-1, 13] 1 0 ultralytics.nn.modules.conv.Concat [1]
19 -1 1 1511424 ultralytics.nn.modules.block.C3k2 [768, 512, 1, True]
20 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
21 [-1, 10] 1 0 ultralytics.nn.modules.conv.Concat [1]
22 -1 1 1642496 ultralytics.nn.modules.block.C3k2 [1024, 512, 1, True]
23 [16, 19, 22] 1 1412566 ultralytics.nn.modules.head.Detect [2, [256, 512, 512]]
YOLO11m summary: 231 layers, 20,054,550 parameters, 20,054,534 gradients, 68.2 GFLOPs
The trained weights and logs are now stored under runs/train/Constrained Training YOLO11m, ready for the next step.
Next, run a quick validation pass on the constrained model.
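A minimal validation sketch, assuming the best.pt produced above and the same own.yaml dataset file; the batch size and output directory are guesses chosen to match the log below:

from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r"runs/train/Constrained Training YOLO11m/weights/best.pt")
    model.val(data=r'own.yaml',
              imgsz=640,
              batch=32,
              device=0,
              workers=0,
              project='runs/val',
              name='exp',
              )

The resulting metrics: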
YOLO11m summary (fused): 125 layers, 20,031,574 parameters, 0 gradients, 67.7 GFLOPs
val: Scanning /home/hairou/ctc/yolo11/Dataset/labels/val.cache... 2000 images, 65 backgrounds, 0 corrupt: 100%|██████████| 2065/2065 [00:00<?, ?it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 65/65 [00:15<00:00, 4.33it/s]
all 2065 4914 0.996 0.982 0.995 0.991
shallow_box_rgb 2000 2027 0.994 0.989 0.995 0.992
shallow_half_box_rgb 1823 2887 0.998 0.975 0.994 0.99
Speed: 0.2ms preprocess, 4.2ms inference, 0.0ms loss, 0.6ms postprocess per image
Results saved to runs/val/exp2
2. Pruning
The code from the referenced article has a few issues, so I made some modifications.
The main idea of this part: the L1 regularization has encouraged sparse weights, so we now remove the channels whose weights are close to zero. A PRUNE class is defined first, with methods for computing the global threshold, pruning a convolution layer, and pruning connected modules. A do_pruning function then loads the model, runs the pruning steps and saves the result; the main block calls do_pruning with the model path and save path.
Create a new file prune_yolo11.py in the project directory. It has three things to set:
- Third line from the bottom: modelpath is the path of the constrained-training .pt model from step 1.
- Second line from the bottom: savepath is where the pruned .pt model will be saved.
- pruning.get_threshold(yolo.model, 0.8) inside do_pruning: the 0.8 controls the pruning strength (the referenced post calls it the pruning ratio). get_threshold sorts all BN weights in descending order and takes the value at the 80% position as the global threshold, so channels whose BN weight falls below it (roughly the bottom 20%) become pruning candidates. Adjust it as needed.
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        # Collect the absolute BN weights (gamma) and biases (beta) of every BatchNorm layer
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
        # Global threshold: the value exceeded by the top `factor` share of all BN weights
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## Normal Pruning: shrink the output channels of conv1
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()
        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  # if fewer than 8 kernels would remain, halve the threshold and retry
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)  # percentage of channels kept in this layer
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            # Segmentation models: adjust the Proto module separately
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## Regular Pruning: shrink the input channels of every downstream conv2
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None:
                continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C3k2 acting as the upstream module: prune its output conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1
        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. Load the model
    yolo = YOLO(modelpath)  # load the constrained-training checkpoint
    pruning.get_threshold(yolo.model, 0.8)  # 0.8 is the threshold factor (see the note above)

    ### 1. Prune the Bottlenecks inside C3k2
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. Prune the convolutions connecting the specified modules
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. Prune the detection head
    # P3: the output of seq[16] feeds seq[17], detect.cv2[0] (box branch) and detect.cv3[0] (class branch)
    # P4: the output of seq[19] feeds seq[20], detect.cv2[1] and detect.cv3[1]
    # P5: the output of seq[22] feeds detect.cv2[2] and detect.cv3[2]
    detect: Detect = seq[-1]
    # proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        pruning.prune(last_input, [colast, cv2[0], cv3[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])

    ### 4. Re-enable gradients, run a quick validation, and save
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True
    yolo.info()
    yolo.val(data='own.yaml', batch=16, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)


if __name__ == "__main__":
    modelpath = "ultralytics-main/runs/train/Constrained Training YOLO11m/weights/best.pt"
    savepath = "ultralytics-main/runs/train/Constrained Training YOLO11m/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
When the script runs, get_threshold first prints, for every BatchNorm layer, the maximum and minimum of |weight| (gamma) and |bias| (beta):
model.0.bn 1.298828125 0.8505859375 0.29248046875 0.002437591552734375
model.1.bn 1.1142578125 0.9462890625 0.1566162109375 4.1961669921875e-05
model.2.cv1.bn 1.08984375 0.9267578125 0.15673828125 0.0005412101745605469
model.2.cv2.bn 1.03515625 0.9755859375 0.054290771484375 4.0531158447265625e-05
model.2.m.0.cv1.bn 1.103515625 0.89111328125 0.219970703125 0.0010194778442382812
model.2.m.0.cv2.bn 1.0498046875 0.9658203125 0.051300048828125 0.00019085407257080078
model.2.m.0.cv3.bn 1.0966796875 0.97119140625 0.06610107421875 0.00041937828063964844
model.2.m.0.m.0.cv1.bn 1.107421875 0.94140625 0.1298828125 0.0010995864868164062
model.2.m.0.m.0.cv2.bn 1.123046875 0.9296875 0.06353759765625 0.00021064281463623047
model.2.m.0.m.1.cv1.bn 1.1025390625 0.912109375 0.0673828125 0.0004315376281738281
model.2.m.0.m.1.cv2.bn 1.119140625 0.982421875 0.0810546875 0.00025916099548339844
model.3.bn 1.04296875 0.978515625 0.052581787109375 4.416704177856445e-05
model.4.cv1.bn 1.0244140625 0.96435546875 0.05096435546875 3.212690353393555e-05
model.4.cv2.bn 1.02734375 0.970703125 0.0245513916015625 3.7550926208496094e-06
model.4.m.0.cv1.bn 1.0234375 0.9150390625 0.049835205078125 6.216764450073242e-05
model.4.m.0.cv2.bn 1.0029296875 0.9765625 0.016265869140625 0.0001055002212524414
model.4.m.0.cv3.bn 1.05859375 0.99365234375 0.042144775390625 1.52587890625e-05
model.4.m.0.m.0.cv1.bn 1.0576171875 0.96826171875 0.037139892578125 0.00024330615997314453
model.4.m.0.m.0.cv2.bn 1.0498046875 0.974609375 0.0245819091796875 8.064508438110352e-05
model.4.m.0.m.1.cv1.bn 1.1005859375 0.9638671875 0.0321044921875 0.000537872314453125
model.4.m.0.m.1.cv2.bn 1.0791015625 0.99560546875 0.02862548828125 9.72747802734375e-05
model.5.bn 1.0234375 0.986328125 0.02447509765625 1.33514404296875e-05
model.6.cv1.bn 1.0224609375 0.96728515625 0.0301513671875 1.8894672393798828e-05
model.6.cv2.bn 1.0478515625 0.98681640625 0.0259857177734375 1.3053417205810547e-05
model.6.m.0.cv1.bn 1.0244140625 0.91455078125 0.045501708984375 0.0001748800277709961
model.6.m.0.cv2.bn 1.0009765625 0.9833984375 0.01180267333984375 8.344650268554688e-06
model.6.m.0.cv3.bn 1.04296875 0.98876953125 0.02777099609375 1.245737075805664e-05
model.6.m.0.m.0.cv1.bn 1.0615234375 0.98046875 0.0362548828125 7.94529914855957e-05
model.6.m.0.m.0.cv2.bn 1.0615234375 0.98291015625 0.03216552734375 0.00010192394256591797
model.6.m.0.m.1.cv1.bn 1.0556640625 0.9814453125 0.037689208984375 4.4226646423339844e-05
model.6.m.0.m.1.cv2.bn 1.0576171875 1.005859375 0.02435302734375 5.143880844116211e-05
model.7.bn 1.0087890625 0.99072265625 0.0201416015625 7.212162017822266e-06
model.8.cv1.bn 1.0087890625 0.98486328125 0.025665283203125 7.152557373046875e-07
model.8.cv2.bn 1.0126953125 0.98583984375 0.0298004150390625 3.159046173095703e-06
model.8.m.0.cv1.bn 1.01171875 0.98193359375 0.0174560546875 4.708766937255859e-06
model.8.m.0.cv2.bn 1.0 0.99267578125 0.00710296630859375 1.4185905456542969e-05
model.8.m.0.cv3.bn 1.015625 0.99169921875 0.022247314453125 1.7881393432617188e-07
model.8.m.0.m.0.cv1.bn 1.0126953125 0.9873046875 0.01384735107421875 1.049041748046875e-05
model.8.m.0.m.0.cv2.bn 1.013671875 0.9921875 0.01549530029296875 8.344650268554688e-06
model.8.m.0.m.1.cv1.bn 1.0068359375 0.99169921875 0.01239013671875 8.344650268554688e-07
model.8.m.0.m.1.cv2.bn 1.01953125 0.9951171875 0.0211639404296875 7.152557373046875e-07
model.9.cv1.bn 1.03515625 0.9794921875 0.0255126953125 1.6689300537109375e-06
model.9.cv2.bn 1.017578125 0.98876953125 0.040008544921875 1.0669231414794922e-05
model.10.cv1.bn 1.0302734375 0.98095703125 0.057891845703125 1.7881393432617188e-06
model.10.cv2.bn 1.0205078125 0.98388671875 0.0238494873046875 4.0531158447265625e-06
model.10.m.0.attn.qkv.bn 1.0087890625 0.98291015625 0.2484130859375 0.0
model.10.m.0.attn.proj.bn 1.013671875 0.9921875 0.0 0.0
model.10.m.0.attn.pe.bn 1.03515625 0.98046875 0.0 0.0
model.10.m.0.ffn.0.bn 1.0068359375 0.99267578125 0.01508331298828125 4.172325134277344e-07
model.10.m.0.ffn.1.bn 1.005859375 0.99072265625 0.0 0.0
model.13.cv1.bn 1.041015625 0.95556640625 0.032257080078125 2.5033950805664062e-06
model.13.cv2.bn 1.0498046875 0.984375 0.0338134765625 1.3828277587890625e-05
model.13.m.0.cv1.bn 1.025390625 0.9423828125 0.0546875 0.0002161264419555664
model.13.m.0.cv2.bn 0.99951171875 0.982421875 0.01377105712890625 1.6748905181884766e-05
model.13.m.0.cv3.bn 1.048828125 0.98291015625 0.0379638671875 1.3113021850585938e-05
model.13.m.0.m.0.cv1.bn 1.1123046875 0.97900390625 0.051513671875 2.282857894897461e-05
model.13.m.0.m.0.cv2.bn 1.09375 0.97900390625 0.04693603515625 0.00019729137420654297
model.13.m.0.m.1.cv1.bn 1.1162109375 0.96826171875 0.03106689453125 9.834766387939453e-06
model.13.m.0.m.1.cv2.bn 1.0986328125 0.99755859375 0.04931640625 0.00026988983154296875
model.16.cv1.bn 1.0810546875 0.9296875 0.06439208984375 4.0531158447265625e-05
model.16.cv2.bn 1.1484375 0.8828125 0.09271240234375 0.0001844167709350586
model.16.m.0.cv1.bn 1.0068359375 0.9208984375 0.05078125 7.933378219604492e-05
model.16.m.0.cv2.bn 1.0078125 0.97509765625 0.0187225341796875 0.0002949237823486328
model.16.m.0.cv3.bn 1.0654296875 0.97412109375 0.05389404296875 0.0001404285430908203
model.16.m.0.m.0.cv1.bn 1.0908203125 0.958984375 0.034576416015625 0.00022542476654052734
model.16.m.0.m.0.cv2.bn 1.10546875 0.9638671875 0.054840087890625 0.00028228759765625
model.16.m.0.m.1.cv1.bn 1.1123046875 0.95654296875 0.03814697265625 0.0008211135864257812
model.16.m.0.m.1.cv2.bn 1.0810546875 0.98779296875 0.045806884765625 1.811981201171875e-05
model.17.bn 1.0107421875 0.9912109375 0.01197052001953125 8.344650268554688e-07
model.19.cv1.bn 1.0166015625 0.98779296875 0.0159149169921875 8.940696716308594e-07
model.19.cv2.bn 1.0234375 0.98486328125 0.0160980224609375 2.1457672119140625e-06
model.19.m.0.cv1.bn 1.00390625 0.96923828125 0.0095672607421875 1.7881393432617188e-06
model.19.m.0.cv2.bn 1.0009765625 0.99365234375 0.00296783447265625 4.76837158203125e-07
model.19.m.0.cv3.bn 1.013671875 0.9892578125 0.0101470947265625 2.1457672119140625e-06
model.19.m.0.m.0.cv1.bn 1.0244140625 0.98681640625 0.00632476806640625 3.7550926208496094e-06
model.19.m.0.m.0.cv2.bn 1.0458984375 0.99072265625 0.0086669921875 7.808208465576172e-06
model.19.m.0.m.1.cv1.bn 1.0263671875 0.9892578125 0.00522613525390625 2.980232238769531e-07
model.19.m.0.m.1.cv2.bn 1.025390625 0.9921875 0.008575439453125 4.172325134277344e-07
model.20.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.cv1.bn 1.0009765625 0.99853515625 0.0 0.0
model.22.cv2.bn 0.99951171875 0.9990234375 0.0007128715515136719 0.0
model.22.m.0.cv1.bn 0.99951171875 0.99853515625 0.0 0.0
model.22.m.0.cv2.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.cv3.bn 1.0 0.99853515625 0.0 0.0
model.22.m.0.m.0.cv1.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.m.0.cv2.bn 0.99951171875 0.99853515625 0.0 0.0
model.22.m.0.m.1.cv1.bn 0.99951171875 0.9990234375 0.0 0.0
model.22.m.0.m.1.cv2.bn 0.99951171875 0.9990234375 0.0 0.0
model.23.cv2.0.0.bn 1.158203125 0.9267578125 0.11334228515625 0.0004303455352783203
model.23.cv2.0.1.bn 1.255859375 0.9365234375 0.360107421875 0.0012311935424804688
model.23.cv2.1.0.bn 1.099609375 0.962890625 0.0748291015625 0.00022161006927490234
model.23.cv2.1.1.bn 1.2255859375 0.96875 0.34814453125 0.0012159347534179688
model.23.cv2.2.0.bn 0.9990234375 0.9990234375 0.0 0.0
model.23.cv2.2.1.bn 0.9990234375 0.9990234375 0.0 0.0
model.23.cv3.0.0.0.bn 1.1171875 0.951171875 0.0545654296875 9.894371032714844e-06
model.23.cv3.0.0.1.bn 1.005859375 0.96630859375 0.05450439453125 9.59634780883789e-06
model.23.cv3.0.1.0.bn 1.12109375 0.9296875 0.059326171875 1.9431114196777344e-05
model.23.cv3.0.1.1.bn 1.1337890625 0.91748046875 0.2464599609375 0.0006527900695800781
model.23.cv3.1.0.0.bn 1.0263671875 0.98681640625 0.0114593505859375 5.960464477539063e-08
model.23.cv3.1.0.1.bn 1.00390625 0.9775390625 0.0640869140625 1.1920928955078125e-07
model.23.cv3.1.1.0.bn 1.0478515625 0.970703125 0.0250396728515625 9.5367431640625e-07
model.23.cv3.1.1.1.bn 1.1005859375 0.9873046875 0.2313232421875 0.0007276535034179688
model.23.cv3.2.0.0.bn 1.0 0.998046875 0.0006380081176757812 0.0
model.23.cv3.2.0.1.bn 0.99951171875 0.9990234375 0.0020084381103515625 0.0
model.23.cv3.2.1.0.bn 1.0009765625 0.9951171875 0.003658294677734375 0.0
model.23.cv3.2.1.1.bn 1.0302734375 0.9921875 0.193115234375 0.0001518726348876953
Reminder: this code only works with the official YOLO11 architecture. It does not apply to other versions or custom models; for those, prune_yolo11.py must be reworked accordingly!
After running prune_yolo11.py, a new .pt file is generated; this is the pruned model.
Now validate the pruned model (do_pruning already calls yolo.val at the end):
YOLO11m summary (fused): 125 layers, 17,723,891 parameters, 13,270 gradients, 56.4 GFLOPs
val: Scanning /home/hairou/ctc/yolo11/Dataset/labels/val.cache... 2000 images, 65 backgrounds, 0 corrupt: 100%|██████████| 2065/2065 [00:00<?, ?it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 130/130 [00:28<00:00, 4.53it/s]
all 2065 4914 0.5 0.000247 0.25 0.2
shallow_box_rgb 2000 2027 1 0.000493 0.5 0.4
shallow_half_box_rgb 1823 2887 0 0 0 0
Speed: 0.1ms preprocess, 3.6ms inference, 0.0ms loss, 0.3ms postprocess per image
Both the parameter count and the FLOPs have dropped, but the accuracy has collapsed, which is why the third step, fine-tuning, is needed.
Author's note: with this saving code, it is normal that the generated last_prune.pt (67.85 MB) is larger than the original best.pt (38.63 MB); most likely because Ultralytics stores trained weights in half precision (FP16), while the checkpoint written here holds full-precision (FP32) tensors.
3. Fine-tuning (post-pruning retraining, updated 2025-04-11)
Fine-tuning differs from ordinary training in that it must load the pruned .pt model from step 2. The Ultralytics framework, however, rebuilds the network from the original YAML config by default, so if you simply launch training with the code below, what actually gets trained is the original (unpruned) YOLO11m structure:
# A classic mistake: this will NOT fine-tune the pruned model
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
from ultralytics.models import RTDETR

if __name__ == '__main__':
    # model = RTDETR(r'ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    model = YOLO(r"runs/train/Constrained Training YOLO11m/weights/last_prune.pt")
    model.train(data=r'own.yaml',
                cache=False,
                imgsz=640,
                epochs=30,
                single_cls=False,  # whether this is single-class detection
                batch=16,
                close_mosaic=10,
                workers=0,
                device=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )
So how do we change this default loading logic?
Step 1: remove the L1-penalty code that was added to BaseTrainer._do_train in ultralytics/engine/trainer.py during constrained training (omitted here).
Step 2: in the setup_model method of the BaseTrainer class in ultralytics/engine/trainer.py, add self.model = weights after the checkpoint has been loaded and the model rebuilt, as sketched below.
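A hedged sketch of where the assignment goes; the surrounding lines follow a recent Ultralytics release (attempt_load_one_weight, get_model and the RANK check may look different in your version), and the added assignment is the only actual change:

# Excerpt of BaseTrainer.setup_model (ultralytics/engine/trainer.py)
weights, ckpt = attempt_load_one_weight(self.model)   # existing: load the .pt checkpoint
cfg = weights.yaml                                     # existing
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # existing: rebuilds the net from YAML
self.model = weights  # added: keep the loaded (pruned) network instead of the YAML-built one
return ckpt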
Step 3: in ultralytics/engine/model.py, modify the train method of the Model class by inserting the following fragment:
# Custom (pruned) model loading in ultralytics/engine/model.py
print("-----------------------------------")
print(f"\033[1;32mINFO\033[0m: custom_model is True, load custom model. ")
for name, param in self.model.named_parameters():
    if "dfl" in name:
        param.requires_grad = False  # freeze the DFL layers
    else:
        param.requires_grad = True   # unfreeze everything else
self.trainer.model.model = self.model.model  # hand the pruned layers over to the trainer's model
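For orientation, a hedged sketch of where this fragment can sit inside Model.train; the surrounding lines again follow a recent Ultralytics release and may differ in yours, and custom_model is an assumed switch that you define yourself (for example a hard-coded flag or an extra keyword you pass through):

# Hypothetical excerpt of Model.train (ultralytics/engine/model.py)
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)  # existing
if not args.get("resume"):  # existing: the trainer builds its model here
    self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
    self.model = self.trainer.model

custom_model = True  # assumed switch: enable only when fine-tuning a pruned checkpoint
if custom_model:
    print("-----------------------------------")
    print(f"\033[1;32mINFO\033[0m: custom_model is True, load custom model. ")
    for name, param in self.model.named_parameters():
        param.requires_grad = "dfl" not in name  # freeze DFL, unfreeze everything else
    self.trainer.model.model = self.model.model  # replace the YAML-built layers with the pruned ones

self.trainer.train()  # existing: training then proceeds as usual

With steps 2 and 3 in place, the trainer keeps the pruned channel layout instead of rebuilding the network from yolo11m.yaml.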
Now launch training again with the same launch script as above; this time it picks up the pruned structure. GPU memory usage during training (5.52 GB) is clearly lower than in the earlier constrained training (7.79 GB).
The final trained .pt file is only 33.98 MB, a clear reduction from the original 38.63 MB.
Before pruning: 231 layers, 20,054,550 parameters, 20,054,534 gradients, 68.2 GFLOPs
After pruning: 125 layers, 17,723,891 parameters, 0 gradients, 56.4 GFLOPs