结构重参数化详解。(bn+conv)与(conv+bn)的融合

原理

如何理解结构重参数化?即 把结构 参数化 ,在训练的时候使用一种复杂的结构,在训练结束后,将多个结构的权重合并,从而在推理时采用另外一种更简单结构加载权重,达到减少计算量与参数量的目的。
比如下面的这个图,如果conv1和conv2的stride一致,那么conv1和conv2可以合并为一个卷积操作。

在这里插入图片描述
如何合并呢?将conv1的权重和conv2的权重相加即可合并为一个卷积层,如果kernel大小不一致,在小的kernel周围填充一圈圈的0,直到大小一致,即可完成合并。这里需要注意在推理阶段需要使用大小不一致的padding以确保生成的特征图一致,否则无法相加。这也是stride不一致的情况下无法合并的原因。而在推理阶段,已经确保将kernel大小一致,那么padding也一致
在公式上可表达为:
y = ( W 1 × x + B 1 ) + ( W 2 × x + B 2 ) y = ( W fuse × x + B fuse ) W fuse = W 1 + W 2 B fuse = B 1 + B 2 y = (W_1 \times x + B_1) + (W_2 \times x + B_2) \\ y = (W_{\text{fuse}} \times x + B_{\text{fuse}}) \\ W_{\text{fuse}} = W_1 + W_2 \\ B_{\text{fuse}} = B_1 + B_2 y=(W1×x+B1)+(W2×x+B2)y=(Wfuse×x+Bfuse)Wfuse=W1+W2Bfuse=B1+B2

融合后的结构:
在这里插入图片描述
那么这就是结构重参数化原理,结构重参数化不仅可以合并水平方向上的分支,也可以合并垂直方向上的操作。残差结构可以被视为卷积核大小为1,并且值为1的卷积操作,那么同样也可以被合并

哪种情况不能合并?

问题来了,有哪些操作不能合并?上面能合并的操作中均不包括激活函数,如果经过激活函数,那么y的计算公式由
y = ( W 1 × x + B 1 ) + ( W 2 × x + B 2 ) y = (W_1 \times x + B_1) + (W_2 \times x + B_2) y=(W1×x+B1)+(W2×x+B2)
变为了
y = ( W 1 × x + B 1 ) + activate ( W 2 × x + B 2 ) y = (W_1 \times x + B_1) + \text{activate}(W_2 \times x + B_2) \\ y=(W1×x+B1)+activate(W2×x+B2)
导致conv2由线性操作变为了非线形操作,就无法合并。如下图
在这里插入图片描述

带bn的合并结构

(conv + bn)

使用卷积层使用bn层,如下图。
在这里插入图片描述
这个结构在公式上可表达为:
y = bn ( conv ( x ) ) y = \text{bn}(\text{conv}(x)) \\ y=bn(conv(x))
复习一下conv和bn的公式,其中mean是平均值,var是方差,eps是避免分母为0,W、B是需要学习的参数。mean、var和eps是不需要学习的
conv = W conv × x + B conv bn = W bn x − mean var + eps + B BN \text{conv} = W_{\text{conv}} \times x + B_{\text{conv}} \\ \text{bn} = W_{\text{bn}} \frac{x - \text{mean}}{\sqrt{\text{var} + \text{eps}}} + B_{\text{BN}} \\ conv=Wconv×x+Bconvbn=Wbnvar+eps xmean+BBN
将conv带入到bn的x,可得下式
y = W b n × W c o n v v a r + e p s × x + B b n + W b n v a r + e p s × ( B c o n v − m e a n ) y = \frac{ W_{bn} \times W_{conv} }{\sqrt{var+eps}}\times x + B_{bn} + \frac{ W_{bn} }{\sqrt{var+eps}} \times (B_{conv} - mean) \\ y=var+eps Wbn×Wconv×x+Bbn+var+eps Wbn×(Bconvmean)
那么可以得到,融合后conv的Wfuse和Bfuse
W f u s e = W b n × W c o n v v a r + e p s B f u s e = B b n + W b n v a r + e p s × ( B c o n v − m e a n ) W_{fuse} = \frac{ W_{bn} \times W_{conv} }{\sqrt{var+eps}} \\ B_{fuse} = B_{bn} + \frac{ W_{bn} }{\sqrt{var+eps}} \times (B_{conv} - mean) Wfuse=var+eps Wbn×WconvBfuse=Bbn+var+eps Wbn×(Bconvmean)

(bn + conv)

使用bn层,使用conv层,如下图:

在这里插入图片描述
这个结构在公式上可表达为:
y = conv ( bn ( x ) ) y = \text{conv}(\text{bn}(x)) \\ y=conv(bn(x))
依旧贴一下bn和conv的公式
bn = W bn x − mean var + eps + B BN conv = W conv × x + B conv \text{bn} = W_{\text{bn}} \frac{x - \text{mean}}{\sqrt{\text{var} + \text{eps}}} + B_{\text{BN}} \\ \text{conv} = W_{\text{conv}} \times x + B_{\text{conv}} \\ bn=Wbnvar+eps xmean+BBNconv=Wconv×x+Bconv
将bn代入conv可得
y = W c o n v × ( W b n × ( x − m e a n v a r + e p s ) + B b n ) + B c o n v y = W c o n v × W b n v a r + e p s × x + ( − W c o n v × W b n v a r + e p s × m e a n + W c o n v × B b n + B c o n v ) y = W_{conv} \times (W_{bn}\times(\frac{x-mean}{\sqrt{var+eps}})+B_{bn})+B_{conv} \\ y = \frac{W_{conv} \times W_{bn}}{\sqrt{var+eps}}\times x + (- \frac{ W_{conv}\times W_{bn}}{\sqrt{var+eps}}\times mean + W_{conv} \times B_{bn} + B_{conv}) y=Wconv×(Wbn×(var+eps xmean)+Bbn)+Bconvy=var+eps Wconv×Wbn×x+(var+eps Wconv×Wbn×mean+Wconv×Bbn+Bconv)
即融合后conv的Wfuse、Bfuse为:
W f u s e = W c o n v × W b n v a r + e p s B f u s e = − W c o n v × W b n v a r + e p s × m e a n + W c o n v × B b n + B c o n v W_{fuse} = \frac{W_{conv} \times W_{bn}}{\sqrt{var+eps}} \\ B_{fuse} = - \frac{ W_{conv}\times W_{bn}}{\sqrt{var+eps}}\times mean + W_{conv} \times B_{bn} + B_{conv} Wfuse=var+eps Wconv×WbnBfuse=var+eps Wconv×Wbn×mean+Wconv×Bbn+Bconv

需要注意:

  1. 以上的公式中忽略了W和B的shape,即他们是矩阵,但在公式中仅以符号代表。先使用bn和后使用bn,bn的channels跟着in_channels 和out_channels走,并且先使用bn层的情况下,还需要考虑groups。稍微复杂一点点
  2. 通常在后面接bn的conv层中,不会添加Bconv

pytorch代码

如正常卷积层一样使用,重点是传入rbr_conv_kernel_list参数。是每个分支结构的卷积核大小。调用reparameterize()会自动合并多分支,结构重参数化后仅有一个conv和激活函数。详情可见https://github.com/balala8/FastViT_pytorch


class RepBlock(nn.Module):
    """
    MobileOne-style residual blocks, including residual joins and re-parameterization convolutions
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        rbr_conv_kernel_list: List[int] = [7, 3],
        use_bn_conv: bool = False,
        act_layer: nn.Module = nn.ReLU,
        skip_include_bn: bool = True,
    ) -> None:
        """Construct a Re-parameterization module.

        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param stride: Stride for convolution.
        :param groups: Number of groups for convolution.
        :param inference_mode: Whether to use inference mode.
        :param rbr_conv_kernel_list: List of kernel sizes for re-parameterizable convolutions.
        :param use_bn_conv: Whether the bn is in front of conv, if false, conv is in front of bn
        :param act_layer: Activation layer.
        :param skip_include_bn: Whether to include bn in skip connection.
        """
        super(RepBlock, self).__init__()

        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.rbr_conv_kernel_list = sorted(rbr_conv_kernel_list, reverse=True)
        self.num_conv_branches = len(self.rbr_conv_kernel_list)
        self.kernel_size = self.rbr_conv_kernel_list[0]
        self.use_bn_conv = use_bn_conv
        self.skip_include_bn = skip_include_bn

        self.activation = act_layer()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=stride,
                padding=self.kernel_size // 2,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            if out_channels == in_channels and stride == 1:
                if self.skip_include_bn:
                    # Use residual connections that include BN
                    self.rbr_skip = nn.BatchNorm2d(num_features=in_channels)
                else:
                    # Use residual connections
                    self.rbr_skip = nn.Identity()
            else:
                # Use residual connections
                self.rbr_skip = None

            # Re-parameterizable conv branches
            rbr_conv = list()
            for kernel_size in self.rbr_conv_kernel_list:
                if self.use_bn_conv:
                    rbr_conv.append(
                        self._bn_conv(
                            in_chans=in_channels,
                            out_chans=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=kernel_size // 2,
                            groups=groups,
                        )
                    )
                else:
                    rbr_conv.append(
                        self._conv_bn(
                            in_chans=in_channels,
                            out_chans=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=kernel_size // 2,
                            groups=groups,
                        )
                    )

            self.rbr_conv = nn.ModuleList(rbr_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.reparam_conv(x))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Other branches
        out = identity_out
        for ix in range(self.num_conv_branches):
            out = out + self.rbr_conv[ix](x)
        return self.activation(out)

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.rbr_conv[0].conv.in_channels,
            out_channels=self.rbr_conv[0].conv.out_channels,
            kernel_size=self.rbr_conv[0].conv.kernel_size,
            stride=self.rbr_conv[0].conv.stride,
            padding=self.rbr_conv[0].conv.padding,
            dilation=self.rbr_conv[0].conv.dilation,
            groups=self.rbr_conv[0].conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        :return: Tuple of (kernel, bias) after fusing branches.
        """

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_skip_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            if self.use_bn_conv:
                _kernel, _bias = self._fuse_bn_conv_tensor(self.rbr_conv[ix])
            else:
                _kernel, _bias = self._fuse_conv_bn_tensor(self.rbr_conv[ix])
            # pad kernel
            if _kernel.shape[-1] < self.kernel_size:
                pad = (self.kernel_size - _kernel.shape[-1]) // 2
                _kernel = torch.nn.functional.pad(_kernel, [pad, pad, pad, pad])

            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_identity
        bias_final = bias_conv + bias_identity
        return kernel_final, bias_final

    def _fuse_skip_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param branch: skip branch, maybe include bn layer
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """

        if not hasattr(self, "id_tensor"):
            input_dim = self.in_channels // self.groups
            kernel_value = torch.zeros(
                (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                dtype=self.rbr_conv[0].conv.weight.dtype,
                device=self.rbr_conv[0].conv.weight.device,
            )
            for i in range(self.in_channels):
                kernel_value[
                    i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                ] = 1
            self.id_tensor = kernel_value
        if isinstance(branch, nn.Identity):
            kernel = self.id_tensor
            return kernel, torch.zeros(
                (self.in_channels),
                dtype=self.rbr_conv[0].conv.weight.dtype,
                device=self.rbr_conv[0].conv.weight.device,
            )
        else:
            assert isinstance(
                branch, nn.BatchNorm2d
            ), "Make sure the module in skip is nn. BatchNorm2d"
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
            std = (running_var + eps).sqrt()
            t = (gamma / std).reshape(-1, 1, 1, 1)
            return kernel * t, beta - running_mean * gamma / std

    def _fuse_bn_conv_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """先bn,后conv

        :param branch:
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = branch.conv.weight
        running_mean = branch.bn.running_mean
        running_var = branch.bn.running_var
        gamma = branch.bn.weight
        beta = branch.bn.bias
        eps = branch.bn.eps
        std = (running_var + eps).sqrt()
        t = gamma / std
        t = torch.stack([t] * (kernel.shape[0] * kernel.shape[1]//t.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
        t_beta = torch.stack([beta] * (kernel.shape[0] * kernel.shape[1]//beta.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
        t_running_mean = torch.stack([running_mean] * (kernel.shape[0] * kernel.shape[1]//running_mean.shape[0]),dim=0).reshape(-1, self.in_channels // self.groups, 1, 1)
        return kernel * t, torch.sum(
            kernel
            * (
                t_beta - t_running_mean * t
            ),
            dim=(1, 2, 3),
        )

    def _fuse_conv_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """First conv, then bn

        :param branch:
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """

        kernel = branch.conv.weight
        running_mean = branch.bn.running_mean
        running_var = branch.bn.running_var
        gamma = branch.bn.weight
        beta = branch.bn.bias
        eps = branch.bn.eps

        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(
        self,
        in_chans: int,
        out_chans: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups: int,
    ) -> nn.Sequential:
        """First conv, then bn

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_chans,
                out_channels=out_chans,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=out_chans))
        return mod_list

    def _bn_conv(
        self,
        in_chans: int,
        out_chans: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups: int,
    ) -> nn.Sequential:
        """Add bn first, then conv"""
        mod_list = nn.Sequential()
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=in_chans))
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_chans,
                out_channels=out_chans,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=False,
            ),
        )
        return mod_list
<think>嗯,用户想了解YOLOv5 Head的网络结构和工作原理。首先,我需要回顾一下YOLO系列的基本知识,尤其是YOLOv5。YOLO(You Only Look Once)是一种单阶段的目标检测算法,速度快且精度较高。YOLOv5是Ultralytics公司开发的一个版本,虽然官方没有正式论文,但开源代码和文档很详细。 YOLOv5的整体结构分为Backbone、Neck和Head三个部分。用户已经提到了引用中的内容,说明Backbone是CSPDarknet,Neck是PANet结构,而Head部分负责最终的检测。我需要重点解析Head的结构。 首先,YOLOv5 Head是典型的解耦头(Decoupled Head),这YOLOv3/v4不同。解耦头将分类和回归任务分开处理,可能提升检测精度。具体结构可能包括卷积层、激活函数、批归一化等组件。 接下来,Head的三个核心功能:生成预测框、计算objectness分数、分类。每个检测头对应不同尺度的特征图,比如大、中、小目标分别用不同尺度的特征图检测。每个特征图的每个锚点会预测多个边界框,参数包括中心坐标、宽高、objectness和类别概率。 然后,需要解释边界框的编码方式,比如使用基于锚框的偏移量预测,以及sigmoid函数处理中心坐标。objectness分数表示存在目标的置信度,而分类使用sigmoid激活函数支持多标签分类,这点可能和YOLOv3的softmax不同。 损失函数方面,YOLOv5 Head的损失包括定位损失(CIoU或GIoU)、置信度损失(二元交叉熵)和分类损失(二元交叉熵)。这些损失函数的组合优化了模型性能。 此外,用户可能想知道Head的优化点,比如解耦头设计和动态锚框调整。动态锚框根据数据集自动调整锚框尺寸,提升检测效果。 需要确认引用内容是否准确。引用[1]提到了detect.py的源码和参数,可能涉及Head的具体实现。引用[2]提到YOLOv8的结构之前版本类似,所以YOLOv5的Head结构可以作为参考。在回答中需要正确引用这些资料,确保信息准确。 最后,生成的相关问题要围绕Head的结构、工作原理、优化方法以及其他版本的对比展开,帮助用户深入理解。需要检查是否有遗漏的关键点,比如不同尺度的检测头如何协同工作,或者具体的层结构细节,确保回答全面且准确。</think>### YOLOv5 Head网络结构及工作原理详解 #### 一、YOLOv5 Head核心结构 YOLOv5 Head是目标检测的最终预测模块,其结构包含**解耦预测头**和**多尺度预测分支**,具体实现如下: $$ \text{Head} = \text{Conv2d}(3\times3) \rightarrow \text{BN} \rightarrow \text{LeakyReLU} \rightarrow \text{Conv2d}(1\times1) $$ ```python class Detect(nn.Module): def __init__(self, nc=80, anchors=()): super().__init__() self.nc = nc # 类别数 self.no = nc + 5 # 输出维度(坐标4+obj1+cls) self.nl = len(anchors) # 检测层数量 self.anchors = anchors # 构建预测卷积层 self.m = nn.ModuleList(nn.Conv2d(x, self.no * 3, 1) for x in [256, 512, 1024]) # 对应三个尺度 ``` #### 二、工作原理解析 1. **多尺度预测机制** - 通过$3$个不同分辨率的特征图进行检测: - $76\times76$(浅层特征):检测小目标 - $38\times38$(中层特征):检测中目标 - $19\times19$(深层特征):检测大目标[^1] 2. **预测参数解码** 每个网格预测$3$个锚框,每个预测结果包含: $$ (t_x, t_y, t_w, t_h, obj\_score, cls_1, ..., cls_n) $$ - 坐标解码公式: $$ \begin{cases} b_x = 2\sigma(t_x) - 0.5 + c_x \\ b_y = 2\sigma(t_y) - 0.5 + c_y \\ b_w = w_a(2\sigma(t_w))^2 \\ b_h = h_a(2\sigma(t_h))^2 \end{cases} $$ 其中$\sigma$为sigmoid函数,$(c_x,c_y)$为网格偏移量,$(w_a,h_a)$为锚框尺寸[^2] 3. **分类置信度** - 分类预测使用sigmoid激活而非softmax,支持多标签检测 - Objectness分数预测使用二元交叉熵损失 #### 三、关键创新点 1. **解耦头设计** 将分类和回归任务分离处理,相比耦合头提升约1.1 AP 2. **动态锚框机制** 在训练前自动计算最佳锚框尺寸: ```python def check_anchors(self, dataset, thr=4.0): # 自动计算锚框标注框的最佳匹配比例 metrics = kmean_anchors(dataset) ``` 3. **损失函数组合** - CIoU Loss:改进的定位损失函数 - Focal Loss:处理类别不平衡问题 - 总损失: $$ L = \lambda_{box}L_{CIoU} + \lambda_{obj}L_{obj} + \lambda_{cls}L_{cls} $$
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值