YOLOv8 Channel Pruning Based on BN Layers

1. Sparsity-Constrained Training

Add a sparsity constraint on the BN layers' scaling factors $\gamma$ and bias terms $\beta$ to the loss; the larger the $\lambda$ coefficients, the stronger the sparsity constraint:

$$L = \sum_{(x,y)} l(f(x), y) + \lambda_1 \sum_{\gamma} g(\gamma) + \lambda_2 \sum_{\beta} g(\beta)$$

For an $L_1$ sparsity constraint:

$$g(\gamma) = |\gamma|, \qquad g(\beta) = |\beta|$$

Directly modifying YOLOv8's loss makes it awkward to ensure the extra term only drives updates to the BN parameters, so instead we modify the BN gradients directly.

Relative to the original gradients, the BN scaling factors and bias terms pick up the following additional terms:

$$\Delta\gamma = \frac{\partial (\lambda_1 \cdot g(\gamma))}{\partial \gamma} = \lambda_1 \cdot \mathrm{sign}(\gamma), \qquad \Delta\beta = \frac{\partial (\lambda_2 \cdot g(\beta))}{\partial \beta} = \lambda_2 \cdot \mathrm{sign}(\beta)$$

During training, $\lambda_1$ is gradually reduced to relax the constraint on $\gamma$ (this stabilizes training and improves consistency between sparse training and fine-tuning):

$$\lambda_1 = 0.01 \cdot \left(1 - 0.9 \cdot \frac{e}{n_e}\right)$$

where $e$ is the current epoch and $n_e$ is the total number of epochs.
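
As a quick check of the sign-based update above (a minimal standalone sketch, not part of the YOLOv8 code), autograd applied to the $L_1$ term reproduces exactly $\lambda \cdot \mathrm{sign}(\gamma)$:

import torch

# The subgradient of lambda * |gamma| w.r.t. gamma equals lambda * sign(gamma)
gamma = torch.tensor([0.3, -1.2, 0.0, 2.5], requires_grad=True)
lam = 0.01
(lam * gamma.abs().sum()).backward()
print(gamma.grad)                        # tensor([ 0.0100, -0.0100,  0.0000,  0.0100])
print(lam * torch.sign(gamma.detach()))  # same values as the closed-form update used below
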
For YOLOv8, we only need to find where the gradient update happens and modify the code there.

Modify the YOLOv8 code around line 390 of ultralytics/engine/trainer.py:

# Backward
self.scaler.scale(self.loss).backward()

# ========== added: L1 sparsity penalty on BN gamma/beta ==========
# (assumes `torch` and `nn` are available in trainer.py; add the imports if they are not)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== end added ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html

Then start training with the following code:

from ultralytics import YOLO

yolo = YOLO("yolov8n.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)

2. Pruning

After sparse training we obtain a best.pt and a last.pt; since fine-tuning follows, starting from last.pt is the better choice.

The structure of YOLOv8 is as follows:

(Figure: YOLOv8 architecture diagram)

Every Conv block in this architecture contains a BN layer. When channel-pruning at a BN layer, we need to remove the corresponding output channels and weights of that Conv, and also remove the matching input channels and weights of the next Conv.
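
To make this concrete, here is a minimal sketch using plain torch.nn modules (not the ultralytics Conv wrapper; all names and sizes here are illustrative) that prunes both sides of one such cut:

import torch
from torch import nn

conv1 = nn.Conv2d(16, 32, 3, padding=1, bias=False)
bn1 = nn.BatchNorm2d(32)
conv2 = nn.Conv2d(32, 64, 3, padding=1, bias=False)

nn.init.normal_(bn1.weight)  # give gamma some spread for the demo
keep = torch.topk(bn1.weight.detach().abs(), k=16).indices.sort().values

conv1.weight.data = conv1.weight.data[keep]          # prune conv1 output channels
conv1.out_channels = len(keep)
bn1.weight.data = bn1.weight.data[keep]              # prune the BN parameters and statistics
bn1.bias.data = bn1.bias.data[keep]
bn1.running_mean.data = bn1.running_mean.data[keep]
bn1.running_var.data = bn1.running_var.data[keep]
bn1.num_features = len(keep)
conv2.weight.data = conv2.weight.data[:, keep]       # prune conv2 input channels
conv2.in_channels = len(keep)

x = torch.randn(1, 16, 64, 64)
print(conv2(bn1(conv1(x))).shape)                    # torch.Size([1, 64, 64, 64])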

The first three layers (0, 1, 2) have few channels, so every channel is relatively important for feature extraction; they are not pruned.

The outputs of layers 4, 6, and 9 are involved in channel concatenation in the head; the structure there is complex and inconvenient to prune, so they are not pruned either.

In addition, other places where Conv layers are not directly consecutive are also left unpruned: for example, inside C2f there is a split operation between the Conv layer and the Bottlenecks, and in the FPN the C2f blocks are interleaved with Upsample and Concat operations.
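
To see which index corresponds to which module (the numbers in parentheses below refer to these indices), you can print the model's top-level layers; a small sketch, assuming the yolov8n weights used above:

from ultralytics import YOLO

model = YOLO("yolov8n.pt").model
for i, m in enumerate(model.model):
    print(i, m.__class__.__name__)   # e.g. 0 Conv, 1 Conv, 2 C2f, ... and Detect last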

With that in mind, the places we can prune include:

Between modules

Backbone:

Conv(3) => C2f(4)
Conv(5) => C2f(6)
Conv(7) => C2f(8)
C2f(8)  => SPPF(9)

Head:

C2f(15) => [Conv(16),Conv(Detect.cv2[0][0]),Conv(Detect.cv3[0][0])]
C2f(18) => [Conv(19),Conv(Detect.cv2[1][0]),Conv(Detect.cv3[1][0])]
C2f(21) => [Conv(Detect.cv2[2]),Conv(Detect.cv3[2])]

Within modules

Besides the junctions between modules listed above, the consecutive Conv layers inside modules fall into two groups:

Bottleneck in C2f

Conv(Bottleneck.cv1) => Conv(Bottleneck.cv2)

cv2, cv3 in Detect

Conv(Detect.cv2[0][0]) => Conv(Detect.cv2[0][1])
Conv(Detect.cv2[0][1]) => Conv2d(Detect.cv2[0][2])
Conv(Detect.cv3[0][0]) => Conv(Detect.cv3[0][1])
Conv(Detect.cv3[0][1]) => Conv2d(Detect.cv3[0][2])

Conv(Detect.cv2[1][0]) => Conv(Detect.cv2[1][1])
Conv(Detect.cv2[1][1]) => Conv2d(Detect.cv2[1][2])
Conv(Detect.cv3[1][0]) => Conv(Detect.cv3[1][1])
Conv(Detect.cv3[1][1]) => Conv2d(Detect.cv3[1][2])

Conv(Detect.cv2[2][0]) => Conv(Detect.cv2[2][1])
Conv(Detect.cv2[2][1]) => Conv2d(Detect.cv2[2][2])
Conv(Detect.cv3[2][0]) => Conv(Detect.cv3[2][1])
Conv(Detect.cv3[2][1]) => Conv2d(Detect.cv3[2][2])

The pruning code is as follows:

import torch
from ultralytics import YOLO
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect


def prune_conv(conv1: Conv, conv2: Conv, threshold=0.01):
    # Prune a top->bottom Conv pair
    # First, prune the BN layer and conv layer of conv1
    # Use conv1's BN weights (gamma) and biases (beta) as the pruning criterion
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    # Indices of the channels kept after pruning
    keep_idxs = []
    local_threshold = threshold
    # Keep at least 8 channels after pruning, which is friendlier to hardware acceleration
    while len(keep_idxs) < 8:
        # Keep the channels whose |gamma| is at or above the threshold
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        # Lower the threshold in case too few channels survive
        local_threshold = local_threshold * 0.5
    # print(local_threshold)
    # Number of channels after pruning
    n = len(keep_idxs)
    # Update the BN parameters
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    # Update the conv weights
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    # Update the conv output channel count
    conv1.conv.out_channels = n
    # Update the conv bias, if present
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    # Then prune the conv layer(s) of conv2
    if not isinstance(conv2, list):
        conv2 = [conv2]

    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            # Update the input channel count
            conv.in_channels = n
            # Update the weights
            conv.weight.data = conv.weight.data[:, keep_idxs]


def prune_module(m1, m2, threshold=0.01):
    # Prune the junction between two modules: take m1's bottom (last) conv and m2's top (first) conv
    # Print the names of m1 and m2
    print(m1.__class__.__name__, end="->")
    if isinstance(m2, list):
        print([item.__class__.__name__ for item in m2])
    else:
        print(m2.__class__.__name__)
    if isinstance(m1, C2f):  # if m1 is a C2f, its bottom conv is cv2
        m1 = m1.cv2
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    prune_conv(m1, m2, threshold)


def prune():
    # Load a model
    yolo = YOLO("last.pt")
    model = yolo.model
    # Collect the weights and biases of all BN layers
    ws = []
    bs = []

    for name, m in model.named_modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            w = m.weight.abs().detach()
            b = m.bias.abs().detach()
            ws.append(w)
            bs.append(b)
            # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

    # Keep 80% of the parameters (channels)
    factor = 0.8
    ws = torch.cat(ws)
    # Sort in descending order and take the value at the 80% position as the threshold
    threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
    print(threshold)

    # First prune inside the Bottleneck modules across the whole network
    for name, m in model.named_modules():
        if isinstance(m, Bottleneck):
            prune_conv(m.cv1, m.cv2, threshold)

    # Then prune the junctions between backbone modules,
    # skipping the layers whose outputs feed the concatenations in the head
    seq = model.model
    for i in range(3, 9):
        if i in [6, 4, 9]: continue
        prune_module(seq[i], seq[i + 1], threshold)

    # Finally prune the junctions between head modules.
    # Head pruning covers two parts: the connection to the next adjacent layer,
    # and the cross-layer connections into the Detect head.
    # last_inputs -> colasts are the adjacent-layer connections; last_inputs -> detect feed the final outputs.
    detect: Detect = seq[-1]
    last_inputs = [seq[15], seq[18], seq[21]]
    colasts = [seq[16], seq[19], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        prune_module(last_input, [colast, cv2[0], cv3[0]], threshold)
        # Prune the junctions inside the Detect head
        prune_module(cv2[0], cv2[1], threshold)
        prune_module(cv2[1], cv2[2], threshold)
        prune_module(cv3[0], cv3[1], threshold)
        prune_module(cv3[1], cv3[2], threshold)

    # Make all parameters trainable again in preparation for retraining
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    # Save the pruned model
    yolo.save("prune.pt")


if __name__ == '__main__':
    prune()
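
Before moving on to fine-tuning, a quick sanity check (using only the APIs already used above) confirms that the pruned checkpoint loads and reports a smaller model:

from ultralytics import YOLO

pruned = YOLO("prune.pt")
pruned.info()  # prints layers, parameters, gradients, GFLOPs for the pruned model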

3. Fine-Tuning

After pruning, the model needs to be fine-tuned. At this point we first remove the sparsity constraint, i.e. comment out the constraint code that was added to the trainer.

Then, during fine-tuning, we must prevent the code from rebuilding the model from the yaml config; it should use the pruned weight model directly instead.

Modification: add the following after line 808 of ultralytics/engine/model.py:

self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# ========== added: reuse the pruned module list instead of the one rebuilt from the yaml ==========
self.trainer.model.model = self.model.model
# ========== end added ==========
self.model = self.trainer.model

Then run fine-tuning with the following code:

from ultralytics import YOLO

yolo = YOLO("prune.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)
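
The comparison in the next section reads the fine-tuned weights from a file named retrain.pt. One way to produce it, sketched here under the assumption that YOLO.save() behaves as in the pruning script and that the in-memory model holds the fine-tuned weights after train(), is:

yolo.save("retrain.pt")  # assumes `yolo` now holds the fine-tuned weights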

4. Comparison

We can compare the accuracy, parameter count, and FLOPs of the sparsely trained original model, the pruned model, and the fine-tuned model:

from ultralytics import YOLO


def compare_prune():
    # Compare parameter count, accuracy, and FLOPs before and after compression
    yolo = YOLO("last.pt")
    before_results = yolo.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    yolo_prune = YOLO("prune.pt")
    prune_results = yolo_prune.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    yolo_retrain = YOLO("retrain.pt")
    retrain_results = yolo_retrain.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)

    # Print parameter count, accuracy, and FLOPs before and after compression
    n_l, n_p, n_g, flops = yolo.info()
    prune_n_l, prune_n_p, prune_n_g, prune_flops = yolo_prune.info()
    retrain_n_l, retrain_n_p, retrain_n_g, retrain_flops = yolo_retrain.info()
    acc = before_results.box.map          # mAP50-95
    prune_acc = prune_results.box.map
    retrain_acc = retrain_results.box.map
    print(f"{'':<10}{'Before':<10}{'Prune':<10}{'Retrain':<10}")
    print(f"{'Params':<10}{n_p:<10}{prune_n_p:<10}{retrain_n_p:<10}")
    print(f"{'FLOPs':<10}{flops:<10}{prune_flops:<10}{retrain_flops:<10}")
    print(f"{'Acc':<10}{acc:<10}{prune_acc:<10}{retrain_acc:<10}")


if __name__ == '__main__':
    compare_prune()