Contents
I. Constrained training
II. Pruning
III. Fine-tuning
This article prunes the official YOLOv8 code (v8.2). Official repository: https://github.com/ultralytics/ultralytics
The code for this article has been uploaded to GitHub: yolov8_prune
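I. Constrained training
1. Rationale
Pruning is a neural network compression technique. It aims to reduce network complexity and improve computational efficiency while preserving model performance as far as possible. Pruning typically removes some of a network's weights or neurons, which shrinks the model and lowers its compute requirements. However, pruning can hurt the model's generalization, because the removed weights may carry information the model depends on.
Constrained training is introduced to mitigate this problem. It adds an extra constraint during training that steers what the model learns, so that the later pruning step does less damage. Adding L1 regularization to the BN (Batch Normalization) layers is a common way to do this.
L1 regularization tends to produce sparse weights, so after training many of the BN layer weights (the per-channel scale factors) are close to zero. In the pruning step that follows, the channels whose factors are close to zero are removed. Because those channels contribute little to the output, removing them has little effect on the predictions.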
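The effect of the L1 term can be written out as follows. This is a sketch in notation introduced here for illustration, not taken from the article's code: L_det is the detection loss, γ ranges over the set Γ of BN scale factors, and λ is the regularization strength.

```latex
L = L_{\text{det}} + \lambda \sum_{\gamma \in \Gamma} |\gamma| ,
\qquad
\frac{\partial}{\partial \gamma} \bigl( \lambda |\gamma| \bigr) = \lambda \, \operatorname{sign}(\gamma) \quad (\gamma \neq 0)
```

The gradient of the penalty has constant magnitude λ no matter how small γ already is, so it keeps pushing small factors toward zero. That is what makes the BN weights sparse.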
2. Procedure
a. Add the following to ./ultralytics/engine/trainer.py, in the training loop between the backward pass and the optimizer step:
# Backward
self.scaler.scale(self.loss).backward()

# ========== added ==========
# 1 constrained training
# L1 coefficient for the BN weights, decayed linearly over the epochs
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        # add the L1 subgradient (coefficient * sign) to the existing gradients
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== added ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
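To see how strong the penalty on the BN weights is at different points in training, the `l1_lambda` expression above can be evaluated on its own. The following is a minimal sketch; the function wrapper and the parameter names `base` and `decay` are introduced here for illustration and do not exist in trainer.py.

```python
def l1_lambda(epoch: int, epochs: int, base: float = 1e-2, decay: float = 0.9) -> float:
    """L1 coefficient applied to the BN weights at a given 0-based epoch."""
    return base * (1 - decay * epoch / epochs)


epochs = 100  # same value as in the training call of this article
for epoch in (0, 50, 99):
    print(f"epoch {epoch:3d}: l1_lambda = {l1_lambda(epoch, epochs):.5f}")
```

The coefficient starts at 1e-2 and shrinks linearly to roughly a tenth of that by the last epoch, so sparsity is pushed hardest early on. The penalty on the BN bias in the snippet above uses a fixed 1e-2 and does not decay.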
b. Start training (/yolov8/train.py):
import os
from ultralytics import YOLO
import torch

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    # build the model from the yaml and load the weights of the previously trained model
    model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')
    model.train(data="data.yaml", amp=False, imgsz=640, epochs=100, batch=20, device=0, workers=0)


if __name__ == '__main__':
    main()
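The article does not say why `amp=False` is passed here. One plausible reason, stated as an assumption: with mixed precision enabled, the gradients left by `backward()` are multiplied by the gradient scaler's loss scale, while the L1 term from step a is added unscaled, so it is shrunk by that factor when the gradients are unscaled before the optimizer step. The arithmetic can be sketched in plain Python; `effective_penalty` and the example gradient value are hypothetical, and 65536 is the default initial scale of PyTorch's `GradScaler`.

```python
def effective_penalty(penalty: float, loss_scale: float, task_grad: float = 0.5) -> float:
    """Penalty that reaches the optimizer when it is added to loss-scaled gradients."""
    scaled_grad = task_grad * loss_scale  # gradient as stored after a scaled backward pass
    scaled_grad += penalty                # L1 term added without scaling
    unscaled = scaled_grad / loss_scale   # the scaler divides by the scale before stepping
    return unscaled - task_grad


print(effective_penalty(1e-2, loss_scale=1.0))      # scaling disabled (amp=False)
print(effective_penalty(1e-2, loss_scale=65536.0))  # scaling enabled
```

With scaling disabled the full 1e-2 penalty is applied; with a scale of 65536 only about 1.5e-7 of it survives, which would leave the BN weights effectively unregularized.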
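II. Pruning
1. Goal: L1 (or L0) regularization has encouraged sparse weights; the weights that ended up close to zero are now pruned away.
2. Procedure: create a new file prune.py under /yolov8/. It has three parameters: yolo loads the model produced by constrained training; res_dir is the path where the pruned model is saved; factor controls how much is pruned. Note that, given how the threshold is computed in the code, factor is the fraction of BN channels kept (0.75 keeps roughly the 75% with the largest absolute weight and prunes the rest), not the fraction removed. The file contents are: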
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy

# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Fraction of BN channels to keep (the rest are pruned)
factor = 0.75

yolo.info()
model = yolo.model

# collect |weight| and |bias| of every BN layer
ws = []
bs = []
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

# keep: global threshold taken from the descending sort of all BN weights
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)
def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    keep_idxs = []
    local_threshold = threshold
    # keep at least 8 channels; halve the threshold until enough remain
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    # print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    # drop the matching input channels of every consumer of conv1
    if not isinstance(conv2, list):
        conv2 = [conv2]
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
def prune(m1, m2):
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    prune_conv(m1, m2)
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]:
        continue
    prune(seq[i], seq[i + 1])

detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

for name, p in yolo.model.named_parameters():
    p.requires_grad = True

# yolo.val(workers=0)  # validate the pruned model
yolo.info()
# yolo.export(format="onnx")  # export to an ONNX file
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # fine-tune directly after pruning
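ckpt = {
    'epoch': -1,
    'best_fitness': None,
    # yolo.model is the module pruned in place above; the original used yolo.ckpt['ema'],
    # which is None when the loaded checkpoint no longer contains an EMA copy
    'model': yolo.model,
    'ema': None,
    'updates': None,
    'optimizer': None,
    'train_args': yolo.ckpt["train_args"],  # save as dict
    'date': None,
    'version': '8.0.142'}
torch.save(ckpt, res_dir)
Finally, the ckpt dict is rewritten so that the saved model file is smaller.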
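The selection logic in prune.py can be tried out on toy numbers without PyTorch. The following is a minimal sketch using plain Python lists. The helper names `global_threshold`, `keep_indices` and `prune_pair` are made up here for illustration, the minimum channel count is set to 2 instead of the script's 8 so the example stays small, and the minimum is additionally capped at the layer width so the loop always terminates.

```python
def global_threshold(gammas_per_layer, factor):
    """Value at position int(N * factor) of all |gamma| sorted in descending order."""
    all_abs = sorted((abs(g) for layer in gammas_per_layer for g in layer), reverse=True)
    return all_abs[int(len(all_abs) * factor)]


def keep_indices(gammas, threshold, min_channels=8):
    """Channels with |gamma| >= threshold; halve the threshold until enough remain."""
    min_channels = min(min_channels, len(gammas))  # cap added in this sketch
    keep, local = [], threshold
    while len(keep) < min_channels:
        keep = [i for i, g in enumerate(gammas) if abs(g) >= local]
        local *= 0.5
    return keep


def prune_pair(producer_w, consumer_w, keep):
    """Weights are indexed [out][in]: drop producer outputs and matching consumer inputs."""
    new_producer = [producer_w[i] for i in keep]
    new_consumer = [[row[i] for i in keep] for row in consumer_w]
    return new_producer, new_consumer


# BN scale factors of two toy layers
layers = [[0.9, 0.01, 0.5, 0.02], [0.3, 0.8, 0.001, 0.6]]
threshold = global_threshold(layers, factor=0.5)
keep0 = keep_indices(layers[0], threshold, min_channels=2)
keep1 = keep_indices(layers[1], threshold, min_channels=2)
print(threshold, keep0, keep1)

# layer 0 has 4 output channels; its consumer has 4 input channels
producer = [[1.0], [2.0], [3.0], [4.0]]
consumer = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
new_producer, new_consumer = prune_pair(producer, consumer, keep0)
print(new_producer, new_consumer)
```

With factor=0.5 the threshold lands on 0.3, so about half of the channels survive: channels 0 and 2 of the first layer and channels 0, 1 and 3 of the second. A larger factor moves the threshold further down the sorted list and keeps more channels.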
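III. Fine-tuning
1. Goal: pruning may temporarily reduce model performance because some useful weights are removed. Fine-tuning lets the model adjust the remaining weights to compensate for the pruned ones, recovering or even improving on the original performance.
2. Procedure:
a. Comment out the L1 regularization previously added to ./ultralytics/engine/trainer.py: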
# Backward
self.scaler.scale(self.loss).backward()

# # ========== added ==========
# # 1 constrained training
# l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
# for k, m in self.model.named_modules():
#     if isinstance(m, nn.BatchNorm2d):
#         m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
#         m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# # ========== added ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni
b. Modify the setup_model function in ./ultralytics/engine/trainer.py:
def setup_model(self):
    """Load/create/download model for any task."""
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    model, weights = self.model, None
    ckpt = None
    if str(model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(model)
        cfg = weights.yaml
    else:
        cfg = model
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    # ========== added ==========
    # 2 finetune: use the loaded (pruned) model directly instead of the one rebuilt from cfg
    self.model = weights
    # ========== added ==========
    return ckpt
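The reason for assigning `weights` directly is presumably that a model rebuilt from the yaml has the original channel counts, so the pruned tensors no longer fit it. To my understanding, ultralytics transfers weights into a freshly built model through a filter that keeps only tensors whose name and shape both match. The sketch below imitates that filter on plain shape tuples; `intersect_by_shape` and the layer names and sizes are hypothetical and only illustrate the idea.

```python
def intersect_by_shape(src: dict, dst: dict) -> dict:
    """Keep the entries of src whose key exists in dst with an identical shape."""
    return {k: shape for k, shape in src.items() if k in dst and dst[k] == shape}


# Hypothetical shapes: the yaml-built layer has 64 channels, the pruned layer 41.
yaml_shapes = {"model.1.conv.weight": (64, 32, 3, 3), "model.1.bn.weight": (64,)}
pruned_shapes = {"model.1.conv.weight": (41, 32, 3, 3), "model.1.bn.weight": (41,)}

transferred = intersect_by_shape(pruned_shapes, yaml_shapes)
print(f"{len(transferred)} of {len(pruned_shapes)} tensors would be transferred")
```

None of the pruned tensors would be transferred in this toy case, so the rebuilt model would start from freshly initialized weights for those layers. Using the loaded pruned module as the training model avoids that.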
c. Start training again (/yolov8/train.py):
import os
from ultralytics import YOLO
import torch

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    # load the pruned checkpoint written by prune.py
    model = YOLO('runs/detect/prune/weights/prune.pt')
    model.train(data="data.yaml", imgsz=640, epochs=100, batch=20, device=0, workers=0)


if __name__ == '__main__':
    main()