RepLKNet-YoloV5

Background

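Since VGG was introduced, a steady stream of CNN architectures has followed. Nearly all of them keep VGG's core design idea: stack many small convolution kernels to get a large receptive field while keeping the parameter count low. For example, two stacked 3x3 convolutions have the same receptive field as one 5x5 convolution, but need only 18 weights per channel pair instead of 25. As ViT has gradually taken SOTA results across vision tasks, CNNs seemed to be running out of steam. RepLKNet changes this picture. Researchers at Megvii found that kernel size is crucial to model performance, especially on downstream tasks such as segmentation and detection. They proposed a CNN that uses very large convolution kernels throughout, combined with structural re-parameterization and depthwise convolution, and RepLKNet achieved SOTA results on a range of vision tasks.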
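The arithmetic behind the VGG argument can be checked with a small sketch. The helper names are hypothetical. For stride-1 convolutions applied one after another, each k x k layer widens the theoretical receptive field by k - 1. The weight count per input/output channel pair is the sum of k^2 over the layers. The last case shows how many weights a single 31x31 kernel (RepLKNet's largest) carries compared with fifteen stacked 3x3 layers that reach the same theoretical receptive field. That cost is one reason the large kernels are applied as depthwise convolutions.

```python
def stacked_receptive_field(kernel_sizes):
    """Theoretical receptive field of stride-1 convolutions applied one after another."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1  # each k x k layer widens the field by k - 1 pixels
    return rf


def weights_per_channel_pair(kernel_sizes):
    """Number of weights per (input channel, output channel) pair, bias ignored."""
    return sum(k * k for k in kernel_sizes)


for name, ks in [("two 3x3", [3, 3]), ("one 5x5", [5]),
                 ("fifteen 3x3", [3] * 15), ("one 31x31", [31])]:
    print(name, stacked_receptive_field(ks), weights_per_channel_pair(ks))
# two 3x3 5 18
# one 5x5 5 25
# fifteen 3x3 31 135
# one 31x31 31 961
```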

This article does not explore the details of the network. It only covers the engineering work of integrating it.

RepLKNet-YoloV5

As before, we take the classic object detection model YOLOv5 as the example and modify its source code as follows.

Environment setup

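First, configure the environment. The paper's optimized implementation for large kernel sizes is already integrated into Megvii's MegEngine framework. To use it with another framework, you need to compile the corresponding source code. For PyTorch: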

conda activate envs  # activate the virtual environment you will train in (replace envs with its name) before installing

git clone https://github.com/MegEngine/cutlass

cd cutlass/examples/19_large_depthwise_conv2d_torch_extension

./setup.py install --user   # if this fails, check CUDA_HOME on your machine

export PYTHONPATH=$PYTHONPATH:/path/where/you/cloned/cutlass/examples/19_large_depthwise_conv2d_torch_extension

export LARGE_KERNEL_CONV_IMPL=/path/where/you/cloned/cutlass/examples/19_large_depthwise_conv2d_torch_extension
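Before training, it helps to confirm that the extension is actually picked up. The get_conv2d function in common.py only tries DepthWiseConv2dImplicitGEMM when LARGE_KERNEL_CONV_IMPL is set and the layer is a stride-1 depthwise convolution with a kernel larger than 5.

- If the variable is unset, get_conv2d silently falls back to nn.Conv2d.
- If the variable is set but the module depthwise_conv2d_implicit_gemm cannot be imported from that path, building the model fails with an ImportError.

The helper below is a hypothetical sketch, not part of either repository, and it mirrors that check:

```python
import importlib
import os
import sys


def large_kernel_impl_available() -> bool:
    """Return True if get_conv2d would be able to use DepthWiseConv2dImplicitGEMM.

    Hypothetical helper: mirrors the LARGE_KERNEL_CONV_IMPL check in common.py.
    """
    path = os.environ.get("LARGE_KERNEL_CONV_IMPL")
    if not path:
        return False  # get_conv2d silently falls back to nn.Conv2d
    if path not in sys.path:
        sys.path.append(path)  # same lookup get_conv2d performs
    try:
        importlib.import_module("depthwise_conv2d_implicit_gemm")
    except ImportError:
        return False  # get_conv2d would raise here instead of falling back
    return True


if __name__ == "__main__":
    print("large-kernel conv available:", large_kernel_impl_available())
```

Run it inside the same shell and conda environment where the two export commands were executed.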
common.py
# Add the following code
#-------------------------------------RepLKNet------------------------------------------------------
import os
import sys

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath  # requires timm (pip install timm)

def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
    if type(kernel_size) is int:
        use_large_impl = kernel_size > 5
        k = kernel_size
    else:
        assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
        use_large_impl = kernel_size[0] > 5
        k = kernel_size[0]  # square kernel: use one side for the padding check below
    has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
    if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == k // 2 and dilation == 1:
        sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
        #   Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
        #   export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
        # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
        # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
        from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
        return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    else:
        return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups, bias=bias)

use_sync_bn = False

def enable_sync_bn():
    global use_sync_bn
    use_sync_bn = True

def get_bn(channels):
    if use_sync_bn:
        return nn.SyncBatchNorm(channels)
    else:
        return nn.BatchNorm2d(channels)

def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    if padding is None:
        padding = kernel_size // 2
    result = nn.Sequential()
    result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                         stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
    result.add_module('bn', get_bn(out_channels))
    return result

def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    if padding is None:
        padding = kernel_size // 2
    result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                         stride=stride, padding=padding, groups=groups, dilation=dilation)
    result.add_module('nonlinear', nn.ReLU())
    return result

def fuse_bn(conv, bn):
    kernel = conv.weight
    running_mean = bn.running_mean
    running_var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std

class ReparamLargeKernelConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, groups,
                 small_kernel,
                 small_kernel_merged=False):
        super(ReparamLargeKernelConv, self).__init__()
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:
            self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
        else:
            self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                      stride=stride, padding=padding, dilation=1, groups=groups)
            if small_kernel is not None:
                assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
                self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
                                             stride=stride, padding=small_kernel//2, groups=groups, dilation=1)

    def forward(self, inputs):
        if hasattr(self, 'lkb_reparam'):
            out = self.lkb_reparam(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, 'small_conv'):
                out += self.small_conv(inputs)
        return out

    def get_equivalent_kernel_bias(self):
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, 'small_conv'):
            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            #   add to the central part
            eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
        return eq_k, eq_b

    def merge_kernel(self):
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
                                     out_channels=self.lkb_origin.conv.out_channels,
                                     kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
                                     padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
                                     groups=self.lkb_origin.conv.groups, bias=True)
        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__('lkb_origin')
        if hasattr(self, 'small_conv'):
            self.__delattr__('small_conv')


class ConvFFN(nn.Module):

    def __init__(self, in_channels, internal_channels, out_channels, drop_path):
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.preffn_bn = get_bn(in_channels)
        self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1)
        self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1)
        self.nonlinear = nn.GELU()

    def forward(self, x):
        out = self.preffn_bn(x)
        out = self.pw1(out)
        out = self.nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKBlock(nn.Module):

    def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
        super().__init__()
        self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
        self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
        self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size,
                                                  stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged)
        self.lk_nonlinear = nn.ReLU()
        self.prelkb_bn = get_bn(in_channels)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # print('drop path:', self.drop_path)

    def forward(self, x):
        out = self.prelkb_bn(x)
        out = self.pw1(out)
        out = self.large_kernel(out)
        out = self.lk_nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKNetStage(nn.Module):

    def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
                 small_kernel, dw_ratio=1, ffn_ratio=4,
                 use_checkpoint=False,      # train with torch.utils.checkpoint to save memory
                 small_kernel_merged=False,
                 norm_intermediate_features=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        blks = []
        for i in range(num_blocks):
            block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
            #   Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
            replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size,
                                     small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged)
            convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels,
                                    drop_path=block_drop_path)
            blks.append(replk_block)
            blks.append(convffn_block)
        self.blocks = nn.ModuleList(blks)
        if norm_intermediate_features:
            self.norm = get_bn(channels)    #   Only use this with RepLKNet-XL on downstream tasks
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)   # Save training memory
            else:
                x = blk(x)
        return x

class RepLKNet_Stem(nn.Module):
    def __init__(self, in_channel, channels):
        super(RepLKNet_Stem, self).__init__()
        base_width = channels[0]
        self.use_checkpoint = True  
        self.stem = nn.ModuleList([
            conv_bn_relu(in_channels=in_channel, out_channels=base_width, kernel_size=3, stride=2, padding=1,
                         groups=1),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1,
                         groups=base_width),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1,
                         groups=base_width)])
    def forward(self, x):
        x = self.stem[0](x)
        for stem_layer in self.stem[1:]:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(stem_layer, x)   # save memory
            else:
                x = stem_layer(x)
        return x

class RepLKNet_stage1(nn.Module):
    def __init__(self, channels, large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super(RepLKNet_stage1, self).__init__()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(channels=channels[0], num_blocks=layers[0],
                              stage_lk_size=large_kernel_sizes[0],
                              drop_path=dpr[sum(layers[:0]):sum(layers[:0 + 1])],
                              small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
                              use_checkpoint=self.use_checkpoint, small_kernel_merged=False,
                              norm_intermediate_features=False)
        self.transition = nn.Sequential(
            conv_bn_relu(channels[0], channels[0 + 1], 1, 1, 0, groups=1),
            conv_bn_relu(channels[0 + 1], channels[0 + 1], 3, stride=2, padding=1,
                         groups=channels[0 + 1]))

    def forward(self, x):
        x = self.layer(x)
        x = self.transition(x)
        return x

class RepLKNet_stage2(nn.Module):
    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super(RepLKNet_stage2, self).__init__()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(channels=channels[1], num_blocks=layers[1],
                                   stage_lk_size=large_kernel_sizes[1],
                                   drop_path=dpr[sum(layers[:1]):sum(layers[:1 + 1])],
                                   small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
                                   use_checkpoint=self.use_checkpoint, small_kernel_merged=False,
                                   norm_intermediate_features=False)
        self.transition = nn.Sequential(
            conv_bn_relu(channels[1], channels[1 + 1], 1, 1, 0, groups=1),
            conv_bn_relu(channels[1 + 1], channels[1 + 1], 3, stride=2, padding=1,
                         groups=channels[1 + 1]))

    def forward(self, x):
        x = self.layer(x)
        x = self.transition(x)
        return x

class RepLKNet_stage3(nn.Module):
    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super(RepLKNet_stage3, self).__init__()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(channels=channels[2], num_blocks=layers[2],
                                   stage_lk_size=large_kernel_sizes[2],
                                   drop_path=dpr[sum(layers[:2]):sum(layers[:2 + 1])],
                                   small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
                                   use_checkpoint=self.use_checkpoint, small_kernel_merged=False,
                                   norm_intermediate_features=False)
        self.transition = nn.Sequential(
            conv_bn_relu(channels[2], channels[2 + 1], 1, 1, 0, groups=1),
            conv_bn_relu(channels[2 + 1], channels[2 + 1], 3, stride=2, padding=1,
                         groups=channels[2 + 1]))

    def forward(self, x):
        x = self.layer(x)
        x = self.transition(x)
        return x

class RepLKNet_stage4(nn.Module):
    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super(RepLKNet_stage4, self).__init__()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(channels=channels[3], num_blocks=layers[3],
                                   stage_lk_size=large_kernel_sizes[3],
                                   drop_path=dpr[sum(layers[:3]):sum(layers[:3 + 1])],
                                   small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
                                   use_checkpoint=self.use_checkpoint, small_kernel_merged=False,
                                   norm_intermediate_features=False)

    def forward(self, x):
        x = self.layer(x)
        return x

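ReparamLargeKernelConv folds its two branches into one kernel using fuse_bn and get_equivalent_kernel_bias. This relies on two linear identities:

1. An inference-mode BatchNorm folds into the preceding convolution as a scaled kernel plus a bias.
2. A small kernel zero-padded to the large size and added to the large kernel gives the same output as running both branches and summing.

The minimal numpy sketch below checks both identities for a single channel. Each depthwise channel is independent, so this is the per-channel case. conv2d_same, bn_eval and fuse are hypothetical helpers written for illustration. The kernel sizes (7 and 3) and BN values are arbitrary stand-ins.

```python
import numpy as np


def conv2d_same(x, k):
    """Single-channel, stride-1 cross-correlation with zero padding k//2 (what nn.Conv2d computes)."""
    kh, kw = k.shape
    xp = np.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros(x.shape)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = np.sum(xp[i:i + kh, j:j + kw] * k)
    return out


def bn_eval(y, gamma, beta, mean, var, eps=1e-5):
    """BatchNorm in inference mode (uses running statistics)."""
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


def fuse(k, gamma, beta, mean, var, eps=1e-5):
    """Same formula as fuse_bn: fold BN into the kernel and a bias."""
    std = np.sqrt(var + eps)
    return k * gamma / std, beta - mean * gamma / std


rng = np.random.default_rng(0)
x = rng.standard_normal((12, 12))
k_large = rng.standard_normal((7, 7))   # stands in for the large kernel (e.g. 31x31)
k_small = rng.standard_normal((3, 3))   # stands in for small_kernel (e.g. 5x5)
bn_large = (1.5, 0.2, 0.1, 2.0)         # gamma, beta, running_mean, running_var (arbitrary)
bn_small = (0.7, -0.3, -0.4, 0.5)

# Training-time structure: BN(large conv) + BN(small conv), as in ReparamLargeKernelConv.forward
y_train = bn_eval(conv2d_same(x, k_large), *bn_large) + bn_eval(conv2d_same(x, k_small), *bn_small)

# Deploy-time structure: fuse each BN, pad the small kernel into the centre, sum kernels and biases
w_l, b_l = fuse(k_large, *bn_large)
w_s, b_s = fuse(k_small, *bn_small)
pad = (k_large.shape[0] - k_small.shape[0]) // 2
y_deploy = conv2d_same(x, w_l + np.pad(w_s, pad)) + (b_l + b_s)

print(np.allclose(y_train, y_deploy))  # True
```

The padding step is why the module requires small_kernel <= kernel_size and uses padding k//2 for both branches. Under those conditions the small kernel lines up with the centre of the large one.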

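The five modules RepLKNet_Stem, RepLKNet_stage1, RepLKNet_stage2, RepLKNet_stage3 and RepLKNet_stage4 each set self.use_checkpoint = True. This enables gradient checkpointing, which trades extra training time for lower memory use: intermediate activations are not stored and are recomputed during the backward pass. See the torch.utils.checkpoint documentation for details on its use.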
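The sketch below shows what torch.utils.checkpoint does. It assumes PyTorch 1.11 or later, which added the use_reentrant argument. The small block is a hypothetical stand-in rather than a RepLKNet block.

The block deliberately contains no BatchNorm. Checkpointing re-runs a block's forward pass during backward, so BatchNorm running statistics in training mode would be updated a second time.

The sketch only checks that checkpointing leaves the gradients unchanged. The memory saving comes from not keeping the block's intermediate activations.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in block (not RepLKNet); no BatchNorm, see the note above
block = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
x0 = torch.randn(2, 8, 16, 16)


def grads(use_checkpoint: bool):
    """Run forward/backward once and return (input grad, first conv weight grad)."""
    block.zero_grad()
    x = x0.clone().requires_grad_(True)
    if use_checkpoint:
        # Intermediate activations inside `block` are not kept; they are recomputed in backward
        y = checkpoint(block, x, use_reentrant=False)
    else:
        y = block(x)
    y.sum().backward()
    return x.grad.clone(), block[0].weight.grad.clone()


gx_plain, gw_plain = grads(use_checkpoint=False)
gx_ckpt, gw_ckpt = grads(use_checkpoint=True)
print(torch.allclose(gx_plain, gx_ckpt), torch.allclose(gw_plain, gw_ckpt))  # True True
```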

yolo.py
# Modify the parse_model function
def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
            c1, c2 = ch[f], args[0]   
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]

            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
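        # Add this branch -------------------------------------------
        elif m in [RepLKNet_Stem, RepLKNet_stage1, RepLKNet_stage2, RepLKNet_stage3, RepLKNet_stage4]:
            c2 = args[0]     # first yaml argument = output channels of the block (width_multiple is not applied)
            args = args[1:]  # the remaining arguments are passed to the constructor
        # ------------------------------------------------------------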
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n_, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
yolov5_RepLKNet.yaml

# Parameters
# Using RepLKNet-31B as an example
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 backbone
backbone:
  [[-1, 1, RepLKNet_Stem, [128, 3, [128,256,512,1024]]],  # 0-P2/4
   [-1, 1, RepLKNet_stage1, [256, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],  # 1-P3/8
   [-1, 1, RepLKNet_stage2, [512, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],  # 2-P4/16
   [-1, 1, RepLKNet_stage3, [1024, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],  # 3-P5/32
   [-1, 1, RepLKNet_stage4, [1024, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],  # 4-P5/32
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [1024, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 2], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [1024, False]],  # 8

   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 1], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [512, False]],  # 12 (P3/8-small)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 9], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [1024, False]],  # 15 (P4/16-medium)

   [-1, 1, Conv, [1024, 3, 2]],
   [[-1, 5], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [2048, False]],  # 18 (P5/32-large)

   [[12, 15, 18], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
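The head refers to backbone outputs by layer index ([-1, 2], [-1, 1], [12, 15, 18]). It is therefore worth checking two things: that each Concat joins feature maps of the same stride, and that Detect reads P3/P4/P5.

The sketch below is a hypothetical helper, not part of YOLOv5. It replays the channel rules from parse_model:

- The RepLKNet branch takes c2 = args[0] without width_multiple.
- Conv and C3 use make_divisible(c2 * gw, 8), re-implemented here as rounding up to a multiple of 8.

Strides are taken from the module definitions in common.py. The stem has two stride-2 convolutions, stage1 to stage3 each end with a stride-2 transition, and stage4 has none. The RepLKNet argument lists are trimmed to their output-channel entry.

```python
import math


def make_divisible(x, divisor=8):
    # Round up to a multiple of `divisor` (the rule YOLOv5's helper applies)
    return math.ceil(x / divisor) * divisor


# (from, module, args) copied from yolov5_RepLKNet.yaml; RepLKNet args trimmed to output channels
LAYERS = [
    (-1, "RepLKNet_Stem", [128]), (-1, "RepLKNet_stage1", [256]), (-1, "RepLKNet_stage2", [512]),
    (-1, "RepLKNet_stage3", [1024]), (-1, "RepLKNet_stage4", [1024]),
    (-1, "Conv", [1024, 1, 1]), (-1, "Upsample", [None, 2]), ([-1, 2], "Concat", [1]), (-1, "C3", [1024, False]),
    (-1, "Conv", [512, 1, 1]), (-1, "Upsample", [None, 2]), ([-1, 1], "Concat", [1]), (-1, "C3", [512, False]),
    (-1, "Conv", [512, 3, 2]), ([-1, 9], "Concat", [1]), (-1, "C3", [1024, False]),
    (-1, "Conv", [1024, 3, 2]), ([-1, 5], "Concat", [1]), (-1, "C3", [2048, False]),
]

# Stride factor contributed by each backbone block (stem: two stride-2 convs)
BACKBONE_STRIDE = {"RepLKNet_Stem": 4, "RepLKNet_stage1": 2, "RepLKNet_stage2": 2,
                   "RepLKNet_stage3": 2, "RepLKNet_stage4": 1}


def trace(layers, gw=0.50):  # gw = width_multiple from the yaml
    ch, stride = [], []
    for i, (f, m, args) in enumerate(layers):
        prev_ch = ch[-1] if ch else 3
        prev_s = stride[-1] if stride else 1
        if m in BACKBONE_STRIDE:
            c2, s = args[0], prev_s * BACKBONE_STRIDE[m]  # width_multiple NOT applied
        elif m in ("Conv", "C3"):
            c2 = make_divisible(args[0] * gw)
            s = prev_s * (args[2] if m == "Conv" else 1)  # Conv args: [c2, kernel, stride]
        elif m == "Upsample":
            c2, s = prev_ch, prev_s // args[1]
        elif m == "Concat":
            srcs = [i - 1 if x == -1 else x for x in f]
            assert len({stride[x] for x in srcs}) == 1, f"layer {i}: concat of mismatched strides"
            c2, s = sum(ch[x] for x in srcs), prev_s
        else:
            raise ValueError(f"unhandled module {m}")
        ch.append(c2)
        stride.append(s)
    return ch, stride


ch, stride = trace(LAYERS)
for i in (12, 15, 18):  # the layers Detect reads from
    print(i, ch[i], stride[i])
# 12 256 8
# 15 512 16
# 18 1024 32
```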

Finally, a comparison of the relevant figures:

Model            Layers  Parameters  GFLOPs
yolov5s          283     7082421     16.4
yolov5l          499     46669045    114.4
RepLKNet-yolov5  854     90261291    275.0

The complete project is available on GitHub:

https://github.com/OutBreak-hui/YoloV5-Flexible-and-Inference
