背景简介
自从VGG提出以后,各种CNN网络层出不穷,但是它们都遵循了VGG的设计思想,通过多个小卷积核叠加来得到大的感受野同时保证较少的参数量(2个3x3的卷积核感受野和5x5的卷积核感受野相同,但是参数量18<25)。随着ViT逐渐在各类视觉任务中拿到SOTA的表现,CNN似乎有点后继无力。RepLKNet打破了这种现象,旷视科技研究员发现,kernel size对模型性能尤其是分割、检测等下游任务至关重要,提出在CNN网络中大量采用超大卷积核的模型并结合结构重参数化、depthwise等,RepLKNet在各类视觉任务中获得了SOTA表现。
本文不做对网络具体细节的探究,仅做工程应用的整合。
RepLKNet-YoloV5
仍然以目标检测经典模型yolov5为例,对源代码做如下修改
环境配置
首先是对环境的配置,论文中对大kernel size的优化已经集成到了旷视的MegEngine框架中,但如果想在其他框架下使用,则需要编译相关源码,pytorch的使用如下:
# Build and expose the large-kernel depthwise conv extension for PyTorch.
conda activate envs # activate the target virtual environment first so the extension is built with its Python/PyTorch
git clone https://github.com/MegEngine/cutlass
cd cutlass/examples/19_large_depthwise_conv2d_torch_extension # `git clone` creates `cutlass`, not `cutlass-master`
./setup.py install --user # if this fails, check CUDA_HOME on this machine
export PYTHONPATH=/本机git clone下来的路径/cutlass/examples/19_large_depthwise_conv2d_torch_extension
export LARGE_KERNEL_CONV_IMPL=/本机git clone下来的路径/cutlass/examples/19_large_depthwise_conv2d_torch_extension
common.py
# 增加如下代码
#-------------------------------------RepLKNet------------------------------------------------------
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath
import sys
import os
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
    """Build a 2-D convolution, using the large-kernel depthwise implementation when eligible.

    The dedicated ``DepthWiseConv2dImplicitGEMM`` layer is returned only when
    all of the following hold: the ``LARGE_KERNEL_CONV_IMPL`` environment
    variable is set, the conv is depthwise (``in_channels == out_channels ==
    groups``), the (square) kernel is larger than 5, ``stride == 1``,
    ``dilation == 1`` and ``padding`` equals half the kernel size.  In every
    other case a regular ``nn.Conv2d`` is returned.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: An int, or a pair of two equal ints.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of groups.
        bias: Whether the convolution has a bias term.

    Returns:
        The constructed convolution module.

    Raises:
        AssertionError: If ``kernel_size`` is not an int and is not a pair of
            equal values.
    """
    # Reduce the kernel size to a scalar once, so the checks below also work
    # when a (k, k) pair is passed (``(k, k) // 2`` would raise TypeError).
    if isinstance(kernel_size, int):
        k = kernel_size
    else:
        assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
        k = kernel_size[0]
    use_large_impl = k > 5
    has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
    is_depthwise = in_channels == out_channels and out_channels == groups

    if (has_large_impl and is_depthwise and use_large_impl
            and stride == 1 and padding == k // 2 and dilation == 1):
        impl_dir = os.environ['LARGE_KERNEL_CONV_IMPL']
        # Register the directory only once instead of growing sys.path on every call.
        if impl_dir not in sys.path:
            sys.path.append(impl_dir)
        #   Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
        #   export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
        #   TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
        #   Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
        from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
        return DepthWiseConv2dImplicitGEMM(in_channels, k, bias=bias)

    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
# Module-wide switch; stays False until enable_sync_bn() is called.
use_sync_bn = False


def enable_sync_bn():
    """Set the module-level ``use_sync_bn`` flag to True."""
    global use_sync_bn
    use_sync_bn = True
def get_bn(channels):
    """Return a batch-norm layer over ``channels`` features.

    Yields ``nn.SyncBatchNorm`` when the module-level ``use_sync_bn`` flag is
    truthy, and ``nn.BatchNorm2d`` otherwise.
    """
    norm_cls = nn.SyncBatchNorm if use_sync_bn else nn.BatchNorm2d
    return norm_cls(channels)
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    """Return an ``nn.Sequential`` of a bias-free convolution ('conv') and batch norm ('bn').

    When ``padding`` is None it defaults to ``kernel_size // 2``.
    """
    pad = kernel_size // 2 if padding is None else padding
    conv = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=pad, dilation=dilation, groups=groups, bias=False)
    seq = nn.Sequential()
    seq.add_module('conv', conv)
    seq.add_module('bn', get_bn(out_channels))
    return seq
def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    """Return the ``conv_bn`` sequence with a ReLU appended under the name 'nonlinear'.

    When ``padding`` is None it defaults to ``kernel_size // 2``.
    """
    pad = kernel_size // 2 if padding is None else padding
    seq = conv_bn(in_channels, out_channels, kernel_size, stride, pad, groups, dilation)
    seq.add_module('nonlinear', nn.ReLU())
    return seq
def fuse_bn(conv, bn):
    """Fold batch-norm statistics into a convolution's weights.

    Args:
        conv: Module exposing a ``weight`` tensor.
        bn: Batch-norm module exposing ``running_mean``, ``running_var``,
            ``weight``, ``bias`` and ``eps``.

    Returns:
        A ``(kernel, bias)`` tuple: the conv weight scaled per output channel
        by ``gamma / std``, and ``beta - running_mean * gamma / std``.
    """
    std = (bn.running_var + bn.eps).sqrt()
    # One scale factor per output channel, broadcast over the remaining kernel dims.
    per_channel_scale = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_kernel = conv.weight * per_channel_scale
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_kernel, fused_bias
class ReparamLargeKernelConv(nn.Module):
    """Large-kernel convolution with an optional parallel small-kernel branch.

    In its un-merged form the module holds ``lkb_origin`` (conv + BN with
    ``kernel_size``) and, when ``small_kernel`` is not None, ``small_conv``
    (conv + BN with ``small_kernel``); ``forward`` sums their outputs.  In its
    merged form it holds a single biased convolution ``lkb_reparam``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, groups,
                 small_kernel,
                 small_kernel_merged=False):
        """Create the branches.

        Args:
            in_channels: Input channels of every branch.
            out_channels: Output channels of every branch.
            kernel_size: Kernel size of the large branch.
            stride: Stride shared by all branches.
            groups: Groups shared by all branches.
            small_kernel: Kernel size of the parallel branch, or None for no
                parallel branch.
            small_kernel_merged: If True, build only the merged convolution
                ``lkb_reparam``.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:
            self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
        else:
            self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                      stride=stride, padding=padding, dilation=1, groups=groups)
            if small_kernel is not None:
                assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
                self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
                                          stride=stride, padding=small_kernel//2, groups=groups, dilation=1)

    def forward(self, inputs):
        """Apply the merged conv if present, else the large branch plus the small branch (if any)."""
        if hasattr(self, 'lkb_reparam'):
            out = self.lkb_reparam(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, 'small_conv'):
                out += self.small_conv(inputs)
        return out

    def get_equivalent_kernel_bias(self):
        """Return the ``(kernel, bias)`` of a single conv equivalent to the un-merged branches.

        Requires the un-merged form (``lkb_origin`` must exist).
        """
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, 'small_conv'):
            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            #   add to the central part
            eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
        return eq_k, eq_b

    def merge_kernel(self):
        """Replace the un-merged branches with a single biased conv ``lkb_reparam``.

        Calling this on a module that is already merged (built with
        ``small_kernel_merged=True`` or merged earlier) is a no-op.
        """
        if hasattr(self, 'lkb_reparam'):
            # Already merged: ``lkb_origin`` no longer exists, so there is nothing to fuse.
            return
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        origin_conv = self.lkb_origin.conv
        self.lkb_reparam = get_conv2d(in_channels=origin_conv.in_channels,
                                      out_channels=origin_conv.out_channels,
                                      kernel_size=origin_conv.kernel_size, stride=origin_conv.stride,
                                      padding=origin_conv.padding, dilation=origin_conv.dilation,
                                      groups=origin_conv.groups, bias=True)
        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__('lkb_origin')
        if hasattr(self, 'small_conv'):
            self.__delattr__('small_conv')
class ConvFFN(nn.Module):
    """Residual feed-forward block: BN, 1x1 conv+BN, GELU, 1x1 conv+BN, drop-path."""

    def __init__(self, in_channels, internal_channels, out_channels, drop_path):
        super().__init__()
        # Drop-path is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.preffn_bn = get_bn(in_channels)
        pointwise = dict(kernel_size=1, stride=1, padding=0, groups=1)
        self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, **pointwise)
        self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, **pointwise)
        self.nonlinear = nn.GELU()

    def forward(self, x):
        """Return ``x`` plus the (drop-path regularised) feed-forward branch output."""
        branch = x
        for layer in (self.preffn_bn, self.pw1, self.nonlinear, self.pw2):
            branch = layer(branch)
        return x + self.drop_path(branch)
class RepLKBlock(nn.Module):
    """Residual block: BN, 1x1 conv+BN+ReLU, large-kernel depthwise conv, ReLU, 1x1 conv+BN."""

    def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
        super().__init__()
        self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
        self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
        # groups == channels, so the large-kernel conv is depthwise.
        self.large_kernel = ReparamLargeKernelConv(
            in_channels=dw_channels,
            out_channels=dw_channels,
            kernel_size=block_lk_size,
            stride=1,
            groups=dw_channels,
            small_kernel=small_kernel,
            small_kernel_merged=small_kernel_merged,
        )
        self.lk_nonlinear = nn.ReLU()
        self.prelkb_bn = get_bn(in_channels)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return ``x`` plus the (drop-path regularised) large-kernel branch output."""
        branch = x
        for layer in (self.prelkb_bn, self.pw1, self.large_kernel, self.lk_nonlinear, self.pw2):
            branch = layer(branch)
        return x + self.drop_path(branch)
class RepLKNetStage(nn.Module):
    """A stack of ``num_blocks`` (RepLKBlock, ConvFFN) pairs at a fixed channel width."""

    def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
                 small_kernel, dw_ratio=1, ffn_ratio=4,
                 use_checkpoint=False,      # train with torch.utils.checkpoint to save memory
                 small_kernel_merged=False,
                 norm_intermediate_features=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        dw_channels = int(channels * dw_ratio)
        ffn_channels = int(channels * ffn_ratio)
        stacked = []
        for idx in range(num_blocks):
            # A list supplies one drop-path rate per pair; a scalar is shared by all pairs.
            rate = drop_path[idx] if isinstance(drop_path, list) else drop_path
            #   Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
            stacked += [
                RepLKBlock(in_channels=channels, dw_channels=dw_channels, block_lk_size=stage_lk_size,
                           small_kernel=small_kernel, drop_path=rate, small_kernel_merged=small_kernel_merged),
                ConvFFN(in_channels=channels, internal_channels=ffn_channels, out_channels=channels,
                        drop_path=rate),
            ]
        self.blocks = nn.ModuleList(stacked)
        #   Only use the BN variant with RepLKNet-XL on downstream tasks
        self.norm = get_bn(channels) if norm_intermediate_features else nn.Identity()

    def forward(self, x):
        """Run every block in order, through ``checkpoint.checkpoint`` when enabled.

        ``self.norm`` is not applied here.
        """
        for blk in self.blocks:
            # Checkpointing saves training memory.
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x
class RepLKNet_Stem(nn.Module):
    """Four conv+BN+ReLU layers at width ``channels[0]``; two of them use stride 2."""

    def __init__(self, in_channel, channels):
        super().__init__()
        width = channels[0]
        self.use_checkpoint = True
        layer_specs = [
            dict(in_channels=in_channel, kernel_size=3, stride=2, padding=1, groups=1),
            dict(in_channels=width, kernel_size=3, stride=1, padding=1, groups=width),
            dict(in_channels=width, kernel_size=1, stride=1, padding=0, groups=1),
            dict(in_channels=width, kernel_size=3, stride=2, padding=1, groups=width),
        ]
        self.stem = nn.ModuleList([conv_bn_relu(out_channels=width, **spec) for spec in layer_specs])

    def forward(self, x):
        """Apply the stem layers in order.

        The first layer is always called directly; the remaining layers go
        through ``checkpoint.checkpoint`` when ``use_checkpoint`` is set.
        """
        x = self.stem[0](x)
        for layer in self.stem[1:]:
            # Checkpointing saves memory.
            x = checkpoint.checkpoint(layer, x) if self.use_checkpoint else layer(x)
        return x
class RepLKNet_stage1(nn.Module):
    """Stage 0 of the backbone followed by a 1x1 conv and a stride-2 depthwise 3x3 transition."""

    def __init__(self, channels, large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super().__init__()
        stage = 0
        # Rates ramp linearly from 0 to drop_path_rate over the blocks of all stages;
        # this stage takes the slice belonging to its own blocks.
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(layers))]
        first, last = sum(layers[:stage]), sum(layers[:stage + 1])
        width, next_width = channels[stage], channels[stage + 1]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(
            channels=width,
            num_blocks=layers[stage],
            stage_lk_size=large_kernel_sizes[stage],
            drop_path=rates[first:last],
            small_kernel=small_kernel,
            dw_ratio=dw_ratio,
            ffn_ratio=ffn_ratio,
            use_checkpoint=self.use_checkpoint,
            small_kernel_merged=False,
            norm_intermediate_features=False,
        )
        pointwise = conv_bn_relu(width, next_width, 1, 1, 0, groups=1)
        depthwise = conv_bn_relu(next_width, next_width, 3, stride=2, padding=1, groups=next_width)
        self.transition = nn.Sequential(pointwise, depthwise)

    def forward(self, x):
        """Run the stage, then the transition."""
        return self.transition(self.layer(x))
class RepLKNet_stage2(nn.Module):
    """Stage 1 of the backbone followed by a 1x1 conv and a stride-2 depthwise 3x3 transition."""

    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super().__init__()
        stage = 1
        # Rates ramp linearly from 0 to drop_path_rate over the blocks of all stages;
        # this stage takes the slice belonging to its own blocks.
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(layers))]
        first, last = sum(layers[:stage]), sum(layers[:stage + 1])
        width, next_width = channels[stage], channels[stage + 1]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(
            channels=width,
            num_blocks=layers[stage],
            stage_lk_size=large_kernel_sizes[stage],
            drop_path=rates[first:last],
            small_kernel=small_kernel,
            dw_ratio=dw_ratio,
            ffn_ratio=ffn_ratio,
            use_checkpoint=self.use_checkpoint,
            small_kernel_merged=False,
            norm_intermediate_features=False,
        )
        pointwise = conv_bn_relu(width, next_width, 1, 1, 0, groups=1)
        depthwise = conv_bn_relu(next_width, next_width, 3, stride=2, padding=1, groups=next_width)
        self.transition = nn.Sequential(pointwise, depthwise)

    def forward(self, x):
        """Run the stage, then the transition."""
        return self.transition(self.layer(x))
class RepLKNet_stage3(nn.Module):
    """Stage 2 of the backbone followed by a 1x1 conv and a stride-2 depthwise 3x3 transition."""

    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super().__init__()
        stage = 2
        # Rates ramp linearly from 0 to drop_path_rate over the blocks of all stages;
        # this stage takes the slice belonging to its own blocks.
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(layers))]
        first, last = sum(layers[:stage]), sum(layers[:stage + 1])
        width, next_width = channels[stage], channels[stage + 1]
        self.use_checkpoint = True
        self.layer = RepLKNetStage(
            channels=width,
            num_blocks=layers[stage],
            stage_lk_size=large_kernel_sizes[stage],
            drop_path=rates[first:last],
            small_kernel=small_kernel,
            dw_ratio=dw_ratio,
            ffn_ratio=ffn_ratio,
            use_checkpoint=self.use_checkpoint,
            small_kernel_merged=False,
            norm_intermediate_features=False,
        )
        pointwise = conv_bn_relu(width, next_width, 1, 1, 0, groups=1)
        depthwise = conv_bn_relu(next_width, next_width, 3, stride=2, padding=1, groups=next_width)
        self.transition = nn.Sequential(pointwise, depthwise)

    def forward(self, x):
        """Run the stage, then the transition."""
        return self.transition(self.layer(x))
class RepLKNet_stage4(nn.Module):
    """Stage 3 of the backbone; unlike the earlier stage wrappers it has no transition."""

    def __init__(self, channels, large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2], drop_path_rate=0.3,
                 small_kernel=5, dw_ratio=1, ffn_ratio=4):
        super().__init__()
        stage = 3
        # Rates ramp linearly from 0 to drop_path_rate over the blocks of all stages;
        # this stage takes the slice belonging to its own blocks.
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(layers))]
        first, last = sum(layers[:stage]), sum(layers[:stage + 1])
        self.use_checkpoint = True
        self.layer = RepLKNetStage(
            channels=channels[stage],
            num_blocks=layers[stage],
            stage_lk_size=large_kernel_sizes[stage],
            drop_path=rates[first:last],
            small_kernel=small_kernel,
            dw_ratio=dw_ratio,
            ffn_ratio=ffn_ratio,
            use_checkpoint=self.use_checkpoint,
            small_kernel_merged=False,
            norm_intermediate_features=False,
        )

    def forward(self, x):
        """Run the stage."""
        return self.layer(x)
RepLKNet_Stem, RepLKNet_stage1, RepLKNet_stage2, RepLKNet_stage3, RepLKNet_stage4四个模型块中的self.use_checkpoint = True可以理解为用训练的时间换取内存空间,详细使用方法可参考torch.utils.checkpoint的使用
yolo.py
# 修改parse_model函数
def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build the network described by a parsed model dictionary.

    Args:
        d: Model dictionary with the keys 'anchors', 'nc', 'depth_multiple',
            'width_multiple', 'backbone' and 'head'.  Each backbone/head entry
            is ``[from, number, module, args]``.
        ch: List of input channel counts; its last element seeds ``c2``.

    Returns:
        A tuple ``(model, save)`` where ``model`` is an ``nn.Sequential`` of
        the constructed layers and ``save`` is the sorted list of layer
        indices referenced by a 'from' field other than -1.

    NOTE: module names and string arguments are passed to ``eval``; only load
    model definitions from trusted sources.
    """
    LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except Exception:
                # Best effort: a string that does not evaluate (e.g. 'nearest') is kept
                # as a literal.  ``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt / SystemExit are not swallowed.
                pass

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        # Added for RepLKNet: the first YAML argument is the block's output channel
        # count (taken as-is, no width scaling); the rest go to the constructor.
        # ---------------------------------------------
        elif m in [RepLKNet_Stem, RepLKNet_stage1, RepLKNet_stage2, RepLKNet_stage3, RepLKNet_stage4]:
            c2 = args[0]
            args = args[1:]
        # ----------------------------------------------
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        n_params = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
        LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, n_params, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            # The first layer replaces the seed input-channel list.
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
yolov5_RepLKNet.yaml
# Parameters
# 以RepLKNet31B为例
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 backbone
backbone:
[[-1, 1, RepLKNet_Stem, [128, 3, [128,256,512,1024]]],
[-1, 1, RepLKNet_stage1, [256, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],
[-1, 1, RepLKNet_stage2, [512, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],
[-1, 1, RepLKNet_stage3, [1024, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],
[-1, 1, RepLKNet_stage4, [1024, [128,256,512,1024], [31,29,27,13], [2,2,18,2], 0.3, 5, 1, 4]],
]
# YOLOv5 head
head:
[[-1, 1, Conv, [1024, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 2], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [1024, False]], # 8
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 1], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [512, False]], # 12 (P3/8-small)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 9], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [1024, False]], # 15 (P4/16-medium)
[-1, 1, Conv, [1024, 3, 2]],
[[-1, 5], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [2048, False]], # 18 (P5/32-large)
[[12, 15, 18], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
最后看一下相关参数对比吧
model | Summary | parameters | GFLOPs |
---|---|---|---|
yolov5s | 283 layers | 7082421 | 16.4 |
yolov5l | 499 layers | 46669045 | 114.4 |
RepLKNet-yolov5 | 854 layers | 90261291 | 275.0 |
完整工程见github:
https://github.com/OutBreak-hui/YoloV5-Flexible-and-Inference