Torch Paper Reproduction: Structural Re-parameterization (RepVGGBlock)

ShuffleNet v2 proposed four major design guidelines for lightweight networks:

  • When the numbers of input and output channels are equal, MAC (memory access cost) is minimized
  • Under the same FLOPs, group convolution with too many groups increases MAC
  • Fragmented operators (multi-branch structures) are unfriendly to parallel acceleration
  • The memory access and latency introduced by element-wise operations are not negligible

In recent years, convolutional network architectures have grown increasingly complex; thanks to their good convergence behavior during training, multi-branch structures have become more and more popular.

However, a multi-branch structure cannot make full use of parallel acceleration on one hand, and increases MAC on the other.


To let a plain structure reach accuracy comparable to a multi-branch one, RepVGG trains with a multi-branch structure (3×3 conv + 1×1 conv + identity mapping) to benefit from its good convergence, and at inference / deployment time converts the multi-branch structure into a single-path one via re-parameterization, to benefit from the extreme speed of the plain structure.


Re-parameterization

In the multi-branch structure used for training, every branch contains a BN layer.

A BN layer has four parameters involved in the computation: mean, var, weight and bias. It applies the following transform to the input x (the small ε added to var for numerical stability is omitted here, but kept in the code below):

$$BN(x) = weight \cdot \frac{x - mean}{\sqrt{var}} + bias$$

Rewriting it in the form $BN(x) = w_{bn} \cdot x + b_{bn}$ gives:

$$w_{bn} = \frac{weight}{\sqrt{var}}, \quad b_{bn} = bias - \frac{weight \cdot mean}{\sqrt{var}}$$

import torch
from torch import nn


class BatchNorm(nn.BatchNorm2d):

    def unpack(self, detach=False):
        # return the equivalent affine parameters (w, b) such that BN(x) = w * x + b
        mean, bias = self.running_mean, self.bias
        std = (self.running_var + self.eps).float().sqrt().type_as(mean)
        weight = self.weight / std
        eq_param = weight, bias - weight * mean
        return tuple(map(lambda x: x.data, eq_param)) if detach else eq_param


bn = BatchNorm(8).eval()
# initialize random parameters
bn.running_mean.data, bn.running_var.data, bn.weight.data, bn.bias.data = torch.rand([4, 8])

image = torch.rand([1, 8, 1, 1])
print(bn(image).view(-1))
# convert the BN parameters into the (w, b) form
weight, bias = bn.unpack()
print(image.view(-1) * weight + bias)

Because the BN layer already fits a per-channel bias, the convolution layer is used without a bias when it is followed by a BN layer. Their combined computation can be written as:

$$Conv(x) = w_{c} * x$$

$$BN(Conv(x)) = w_{bn} w_{c} * x + b_{bn}$$

As shown, a convolution layer followed by a BN layer is equivalent to a single convolution layer with a bias.
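
As a quick sanity check of this equivalence, the minimal sketch below (layer sizes and random statistics are my own choices, not from the original post) fuses a bias-free 3×3 convolution and a BN layer in eval mode into one biased convolution:

import torch
from torch import nn

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
# give the BN layer non-trivial statistics and affine parameters
bn.running_mean.data, bn.weight.data, bn.bias.data = torch.rand([3, 8])
bn.running_var.data = torch.rand(8) + 0.5

# w_bn, b_bn as derived above (eps is kept inside the square root)
w_bn = bn.weight / (bn.running_var + bn.eps).sqrt()
b_bn = bn.bias - w_bn * bn.running_mean

# fused convolution: scale each output-channel kernel by w_bn, use b_bn as the bias
fused = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True).eval()
fused.weight.data = conv.weight.data * w_bn.data.view(-1, 1, 1, 1)
fused.bias.data = b_bn.data

x = torch.rand([1, 8, 16, 16])
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-6))  # True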


The identity mapping is likewise equivalent to a 1×1 convolution (a short check follows the list below):

  • For nn.Conv2d(c1, c2, kernel_size=1), the weight has shape [c2, c1, 1, 1] — it can be viewed as a [c2, c1] linear layer that performs a channel transform at every pixel (see: how 2-D multi-channel convolution is computed in Torch)
  • When c1 = c2 and this linear layer is the identity matrix, the convolution is equivalent to the identity mapping
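
A minimal check of the second point (shapes chosen arbitrarily): a 1×1 convolution whose [c, c] weight is the identity matrix reproduces its input.

import torch
from torch import nn

c = 8
identity = nn.Conv2d(c, c, kernel_size=1, bias=False)
# reshape a [c, c] identity matrix into the [c, c, 1, 1] kernel layout
identity.weight.data = torch.eye(c).view(c, c, 1, 1)

x = torch.rand([1, c, 5, 5])
print(torch.allclose(identity(x), x))  # True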

A 1×1 convolution can in turn be written as a 3×3 convolution by zero-padding its kernel, so the computation of the whole multi-branch structure can be expressed as:

$$BN_{3 \times 3}(Conv_{3 \times 3}(x)) = w_3 * x + b_3$$

$$BN_{1 \times 1}(Conv_{1 \times 1}(x)) = w_1 * x + b_1$$

$$BN_{id}(Conv_{id}(x)) = w_0 * x + b_0$$

$$y = (w_3 + w_1 + w_0) * x + (b_3 + b_1 + b_0)$$

which is therefore equivalent to a single new 3×3 convolution (the conclusion also extends to group convolution and larger kernels such as 5×5).
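
The fusion can be verified numerically with functional convolutions. In the sketch below the per-branch (w, b) pairs stand in for the results of the Conv-BN fusions above (random values rather than real BN statistics), the 1×1 kernels are zero-padded to 3×3 with F.pad, and the summed kernel and bias reproduce the sum of the three branch outputs:

import torch
import torch.nn.functional as F

c = 4
x = torch.rand([1, c, 8, 8])

# equivalent (weight, bias) of each branch: 3×3 conv, 1×1 conv, identity mapping
w3, b3 = torch.rand([c, c, 3, 3]), torch.rand(c)
w1, b1 = torch.rand([c, c, 1, 1]), torch.rand(c)
w0, b0 = torch.eye(c).view(c, c, 1, 1), torch.rand(c)

# zero-pad the 1×1 kernels to 3×3 and merge everything into one kernel and one bias
w = w3 + F.pad(w1, (1, 1, 1, 1)) + F.pad(w0, (1, 1, 1, 1))
b = b3 + b1 + b0

branches = (F.conv2d(x, w3, b3, padding=1)
            + F.conv2d(x, w1, b1)
            + F.conv2d(x, w0, b0))
fused = F.conv2d(x, w, b, padding=1)
print(torch.allclose(branches, fused, atol=1e-6))  # True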

In a speed test on an NVIDIA 1080Ti, feeding a [32, 2048, 56, 56] input through convolutions that keep the channel count and spatial size unchanged, the 3×3 convolution achieves the highest floating-point throughput (FLOPS), since libraries such as cuDNN are heavily optimized for 3×3 kernels (e.g. via the Winograd algorithm).
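
The exact numbers depend on the GPU and library versions, but a throughput comparison of this kind can be reproduced roughly as follows (the tensor shapes follow the setting above; the iteration count and the plain wall-clock timer are my own simplifications, and batch / channels may need to be reduced on GPUs with less memory):

import time
import torch
from torch import nn


@torch.no_grad()
def tflops(k, c=2048, size=56, batch=32, iters=20):
    # same-channel, same-resolution convolution with kernel size k
    conv = nn.Conv2d(c, c, k, padding=k // 2, bias=False).cuda().eval()
    x = torch.rand([batch, c, size, size], device='cuda')
    conv(x)                              # warm-up
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        conv(x)
    torch.cuda.synchronize()
    cost = (time.time() - t0) / iters
    # theoretical multiply-add count of the convolution, counted as 2 FLOPs each
    flops = 2 * batch * c * c * k * k * size * size
    return flops / cost / 1e12


if torch.cuda.is_available():
    for k in (1, 3, 5, 7):
        print(f'{k}x{k}: {tflops(k):.2f} TFLOPS')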


Reproducing the Structure

参考代码:https://github.com/DingXiaoH/RepVGG

I refactored the source code of the paper to improve its readability and usability (the re-parameterization computation is written as class methods, so an assembled model can be handled conveniently), and to support merging larger kernels (e.g. merging a 7×7 with a 3×3).

I also transferred the re-parameterization technique to the plain CBS module (Conv - BN - SiLU), wrapped as the Conv class. Its class method Conv.reparam fuses every convolution layer with its BN layer throughout an assembled model.

On top of that I wrote the class method reparam for RepConv to merge the multi-branch structure. It has been verified that the merge works in all cases, including dilated convolution and group convolution.

from collections import OrderedDict
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

# sum a sequence of tensors, using the first element instead of the int 0 as the start value
sum_ = lambda x: sum(x[1:], x[0])


def auto_pad(k, s=1, d=1):
    # (k - 1) // 2 * d: shift the convolution center onto pixel [0, 0] of the image
    # (s - 1) // 2: shift the convolution center onto [s/2, s/2]
    return max(0, (k - 1) // 2 * d - (s - 1) // 2)


class BatchNorm(nn.BatchNorm2d):

    def __init__(self, c1, s=1):
        super().__init__(c1)
        self.s = s

    def forward(self, x):
        # s > 1 turns the BN layer into a strided identity mapping (subsample, then normalize)
        return super().forward(x[..., ::self.s, ::self.s])

    def unpack(self, detach=False):
        # return the equivalent affine parameters (w, b) such that BN(x) = w * x + b
        mean, bias = self.running_mean, self.bias
        std = (self.running_var + self.eps).float().sqrt().to(mean)
        weight = self.weight / std
        eq_param = weight, bias - weight * mean
        return tuple(map(lambda x: x.data, eq_param)) if detach else eq_param


class Conv(nn.Module):
    ''' Conv - BN - Act'''
    # becomes True once the block has been re-parameterized into a single nn.Conv2d
    deploy = property(fget=lambda self: isinstance(self.conv, nn.Conv2d))

    def __init__(self, c1, c2, k=3, s=1, g=1, d=1,
                 act: Optional[nn.Module] = nn.ReLU, ctrpad=True):
        super().__init__()
        assert k & 1, 'The convolution kernel size must be odd'
        # depthwise convolution: g='dw' sets groups equal to the number of input channels
        if g == 'dw':
            g = c1
            assert c1 == c2, 'Failed to create DWConv'
        # keyword arguments for nn.Conv2d
        self._config = dict(
            in_channels=c1, out_channels=c2, kernel_size=k,
            stride=s, padding=auto_pad(k, s if ctrpad else 1, d), groups=g, dilation=d
        )
        self.conv = nn.Sequential(OrderedDict(
            conv=nn.Conv2d(**self._config, bias=False),
            bn=BatchNorm(c2)
        ))
        self.act = act() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    @classmethod
    def reparam(cls, model: nn.Module):
        for m in filter(lambda m: isinstance(m, cls) and not m.deploy, model.modules()):
            kernel = m.conv.conv.weight.data
            bn_w, bn_b = m.conv.bn.unpack(detach=True)
            # fuse nn.Conv2d with BatchNorm into a single biased convolution
            m.conv = nn.Conv2d(**m._config, bias=True)
            m.conv.weight.data, m.conv.bias.data = kernel * bn_w.view(-1, 1, 1, 1), bn_b


class RepConv(nn.Module):
    ''' RepConv
        k: kernel sizes of the branches; 0 denotes the identity mapping'''
    deploy = property(fget=lambda self: isinstance(self.m, nn.Conv2d))

    def __init__(self, c1, c2, k=(0, 1, 3), s=1, g=1, d=1,
                 act: Optional[nn.Module] = nn.ReLU):
        super().__init__()
        # validate the kernel sizes and sort them (largest kernel last)
        klist = sorted(k)
        assert len(klist) > 1, 'RepConv with a single branch is illegal'
        self.m = nn.ModuleList()
        for k in klist:
            # Identity
            if k == 0:
                assert c1 == c2, 'Failed to add the identity mapping branch'
                self.m.append(BatchNorm(c2, s=s))
            # nn.Conv2d + BatchNorm
            elif k > 0:
                assert k & 1, f'The convolution kernel size {k} must be odd'
                self.m.append(Conv(c1, c2, k=k, s=s, g=g, d=d, act=None, ctrpad=False))
            else:
                raise AssertionError(f'Wrong kernel size {k}')
        # Activation
        self.act = act() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.m(x) if self.deploy else sum_(tuple(m(x) for m in self.m)))

    @classmethod
    def reparam(cls, model: nn.Module):
        Conv.reparam(model)
        # traverse all submodules of the model and merge each RepConv
        for m in filter(lambda m: isinstance(m, cls) and not m.deploy, model.modules()):
            # the largest-kernel branch (last after sorting) provides the target config;
            # its weight tensor fixes the device / dtype of the fused convolution
            expp, cfg = m.m[-1].conv.weight, m.m[-1]._config
            conv = nn.Conv2d(**cfg, bias=True).to(expp)
            # swap the ModuleList for the single fused convolution
            mlist, m.m = m.m, conv
            # weight shape is [c2, c1 // g, k, k]
            (c2, c1g, k, _), g = conv.weight.shape, conv.groups
            # zero out the parameters of the fused nn.Conv2d before accumulation
            nn.init.constant_(conv.weight, 0)
            nn.init.constant_(conv.bias, 0)
            for branch in mlist:
                # identity branch (BatchNorm only): place a per-group identity matrix,
                # scaled by the equivalent BN weight, at the kernel center
                if isinstance(branch, BatchNorm):
                    w, b = branch.unpack(detach=True)
                    conv.weight.data[..., k // 2, k // 2] += torch.eye(c1g).repeat(g, 1).to(expp) * w[:, None]
                # conv branch (already fused with its BN by Conv.reparam):
                # zero-pad the smaller kernel to k × k and accumulate
                else:
                    branch = branch.conv
                    p = (k - branch.kernel_size[0]) // 2
                    w, b = branch.weight.data, branch.bias.data
                    conv.weight.data += F.pad(w, (p,) * 4)
                # accumulate the equivalent bias of every branch
                conv.bias.data += b

An assembled model is then designed to verify:

  • whether reparam changes the network structure
  • whether the model output stays identical before and after re-parameterization
  • whether inference becomes faster after re-parameterization

class timer:
    ''' Decorator: time `repeat` calls of the wrapped function and print the cost in ms '''

    def __init__(self, repeat: int = 1, avg: bool = True):
        self.repeat = max(1, int(repeat) if isinstance(repeat, float) else repeat)
        self.avg = avg

    def __call__(self, func):
        import time

        def handler(*args, **kwargs):
            t0 = time.time()
            for i in range(self.repeat): func(*args, **kwargs)
            cost = (time.time() - t0) * 1e3
            print('Cost:', cost / self.repeat if self.avg else cost)
            return func(*args, **kwargs)

        return handler


class RandomModel(nn.Sequential):

    def __init__(self, c1=3, c_=8):
        super().__init__(
            # downsampling convolution: 1×1 and 3×3 branches
            RepConv(c1, c_, k=(1, 3), s=2),
            # dilated convolution: identity, 3×3 and 7×7 branches
            RepConv(c_, c_, k=(0, 3, 7), g=1, d=3),
            # depthwise convolution: identity, 1×1, 3×3 and 9×9 branches
            RepConv(c_, c_, k=(0, 1, 3, 9), g=c_, d=1)
        )

    @timer(10)
    def forward(self, x):
        return super().forward(x).sum(dim=(0, 1))


if __name__ == '__main__':
    model = RandomModel().eval()
    print(model, '\n')

    with torch.no_grad():
        # initialize BatchNorm with random parameters
        for m in filter(lambda m: isinstance(m, nn.BatchNorm2d), model.modules()):
            m.running_mean.data, m.running_var.data, \
                m.weight.data, m.bias.data = torch.rand([4, m.num_features])

        image = torch.rand([1, 3, 10, 10])
        # test with the training-time structure
        print(model(image), '\n')

        # call the RepConv class method to merge the branches
        RepConv.reparam(model)
        print(model, '\n')
        # test with the inference-time structure
        print(model(image), '\n')

Output before merging the branches:

Cost: 3.5999536514282227
tensor([[3.9966, 3.9923, 3.6735, 3.2401, 3.7511],
        [3.0010, 4.5650, 4.4806, 3.0237, 3.9267],
        [2.6274, 3.2868, 3.8479, 2.7222, 3.1459],
        [3.1416, 3.7339, 3.9633, 3.6590, 4.0194],
        [2.7083, 3.8338, 3.7580, 3.3356, 3.2799]]) 

Output after merging the branches:

Cost: 0.5004405975341797
tensor([[3.9966, 3.9923, 3.6735, 3.2401, 3.7511],
        [3.0010, 4.5650, 4.4806, 3.0237, 3.9267],
        [2.6274, 3.2868, 3.8479, 2.7222, 3.1459],
        [3.1416, 3.7339, 3.9633, 3.6590, 4.0194],
        [2.7083, 3.8338, 3.7580, 3.3356, 3.2799]]) 

Model before merging the branches:

RandomModel(
  (0): RepConv(
    (m): ModuleList(
      (0): Conv(
        3, 8, kernel_size=(1, 1), stride=(2, 2), bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
      (1): Conv(
        3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
    )
    (act): ReLU()
  )
  (1): RepConv(
    (m): ModuleList(
      (0): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
      (2): Conv(
        8, 8, kernel_size=(7, 7), stride=(1, 1), padding=(9, 9), dilation=(3, 3), bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
    )
    (act): ReLU()
  )
  (2): RepConv(
    (m): ModuleList(
      (0): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv(
        8, 8, kernel_size=(1, 1), stride=(1, 1), groups=8, bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
      (2): Conv(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
      (3): Conv(
        8, 8, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), groups=8, bias=False
        (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Identity()
      )
    )
    (act): ReLU()
  )
)

Model after merging the branches:

RandomModel(
  (0): RepConv(
    (m): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (act): ReLU()
  )
  (1): RepConv(
    (m): Conv2d(8, 8, kernel_size=(7, 7), stride=(1, 1), padding=(9, 9), dilation=(3, 3))
    (act): ReLU()
  )
  (2): RepConv(
    (m): Conv2d(8, 8, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), groups=8)
    (act): ReLU()
  )
)
