SegNeXt是一个简单的用于语义分割的卷积网络架构,通过对传统卷积结构的改进,在一定的参数规模下超越了transformer模型的性能,同等参数规模下在 ADE20K, Cityscapes,COCO-Stuff, Pascal VOC, Pascal Context, 和 iSAID数据集上的miou比transformer模型高2个点以上。其优越之处在对编码器(backbone)的的改进,将transformer中模型的一些特殊结构引入了传统卷积中,并提出了MSCAAttention结构,在语义分割中的空间attenion中占据一定优势。
论文地址:https://arxiv.org/pdf/2209.08575.pdf
项目地址:https://github.com/Visual-Attention-Network/SegNeXt
1、背景知识
语义分割与分类模型
在语义分割模型[80,96,5]中,其编码器大多在ImageNet数据集上进行预训练(分类训练)。为了捕获高级语义,通常需要一个解码器,它被应用到编码器(backbone)上。
transformer与语义分割
1、SETR [96] and SegFormer为transformer在语义分割中的使用作出改进。
1.1 优秀语义分割模型的特征
1、使用强骨干网络作为编码器(而解码器相对较弱) 基于transformer的语义分割模型性能强劲正在于backbone的强大;此外,很多cnn语义分割模型都是使用在imagenet预训练过的网络做backbone;
2、多尺度的信息交互 与识别主体内容的图像分类不同,语义分割的目标尺度存在较大差异,需要融合不同尺度的感受野
3、空间attention 空间attention可以是模型通过语义区域内的优先级进行区域分割
4、计算量低 这在处理高分辨率的图像中极为重要;此外低flop模型有利于轻量化部署。
1.2 语义分割中的常见trick
通用性技术:
1、使用更强大的backbone做解码器;
2、图像增强技术
定制化技术:
1、使用多尺度接收域(多尺度输入)
2、收集多尺度语义(如hrnet一样实现多尺度并行支路;PSPNet)
3、扩大接收域(孔洞卷积,ASPP )
4、加强边缘特征
4、捕捉全局上下文
1.3 语义分割下的attention
注意机制是一种自适应的选择过程,旨在使网络关注的重要部分。一般来说,语义分割[25]可分为spatial attention和channel attention两类。最近流行的transformer模型,在语义分割上通常忽略了channel attention的适应性。
号外:Visual attention network (VAN)提出了使用大核attention机制来构建通道注意和空间注意,并取得了较好的性能。
2、方法实现
2.1 编码器设计
设计了一种新的多尺度卷积注意(MSCA)模块。如图2 (a)所示,MSCA包含三个部分:深度卷积聚合局部信息,多分支深度条卷积捕获多尺度上下文,以及1×1卷积建模不同通道之间的关系。
将1×1卷积的输出直接作为注意权值来重新称重MSCA的输入。在数学上,我们的MSCA可以写成以下形式。其中,
F
F
F表示输入特征;
S
c
a
l
e
i
,
i
∈
0
,
1
,
2
,
3
Scale_i,i∈{0,1,2,3}
Scalei,i∈0,1,2,3表示图2b中的支路,其中支路0是跳跃连接表示,其他支路使用深度可分卷积来近似大卷积核(分别近似7x7、11x11、21x21)。采用深度可分卷积有两个原因,1:可以在减少参数量的同时近似大卷积核,2:在实际分割中存在条状物,深度可分卷积可作为对普通卷积的补充。
堆叠建块所构成的卷积编码器,名为MSCAN。对于MSCAN,采用了一个共同的层次结构,它包含四个空间分辨率降低的阶段:
H
4
×
W
4
\frac{H}{4} \times \frac{W}{4}
4H×4W、
H
8
×
W
8
\frac{H}{8} \times \frac{W}{8}
8H×8W、
H
16
×
W
16
\frac{H}{16} \times \frac{W}{16}
16H×16W和
H
32
×
W
32
\frac{H}{32} \times \frac{W}{32}
32H×32W。H和W分别为输入图像的高度和宽度
。每个阶段都包含一个下采样块和如上所述的构建块的堆栈。下采样块使用内核大小为3×3步幅2的卷积实现,然后是批处理归一化层[35]。注意,在MSCAN的每个构建块中,使用批归一化而不是层归一化,因为我们发现批归一化在分割性能方面收益更多。
我们设计了四种不同大小的编码器模型,分别命名为MSCAN-T、MSCAN-S、MSCAN-B和MSCAN-L。相应的整体分割模型分别被称为SegNeXt-T、SegNeXt-S、SegNeXt-B、SegNeXt-L。详细的网络设置显示在表2中。
2.2 解码器设计
我们研究了三种简单的解码器结构,如图3所示。第一个在SegFormer[80]中采用,是纯基于MLP的结构。第二种方法大多采用了基于cnn的模型。在这种结构中,编码器的输出被直接用作解码器头的输入,如ASPP [5]、PSP [94]和DANet[19]。最后一个是在我们的SegNeXt中采用的结构。我们聚合了最后三个阶段的特性,并使用一个轻量级的Hamburger[21]来进一步建模全局上下文。结合我们强大的卷积编码器,我们发现使用轻量级解码器可以提高性能计算效率。
与segfrome的解码器聚合了从阶段1到阶段4的特征不同,我们的解码器只接收来自最后三个阶段的特征。这是因为我们的SegNeXt是基于卷积的。阶段1的特性包含太多的低级信息,损害了性能。此外,在阶段1上的操作也带来了巨大的计算开销。在我们的实验部分,我们将展示我们的卷积SegNeXt比最近最先进的基于transformer的SegFormer[80]和HRFormer[88]要好得多。
3、代码实现
3.1 整体配置文件
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='MSCAN',
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
embed_dims=[32, 64, 160, 256],
mlp_ratios=[8, 8, 4, 4],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[64, 160, 256],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=16,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
其中编码器的完整实现代码为:https://github.com/open-mmlab/mmsegmentation/blob/70477d21adfb4caaf2e35c8e6d88b9ee740486d6/mmseg/models/backbones/mscan.py
其中解码器的完整实现代码为: https://github.com/open-mmlab/mmsegmentation/blob/70477d21adfb4caaf2e35c8e6d88b9ee740486d6/mmseg/models/decode_heads/ham_head.py
这里主要分析编码器的实现
3.2 MSCA实现代码
MSCA的可分离卷积是由Conv2d所实现的(其特点为,padding也是对应方向的,输入输出channel一致,groups与channel一致),输入先由self.conv0做卷积,然后再交给4个支路(三个深度可分离卷积+一个直通连接);最后的self.conv3是一个正常的conv1x1卷积,用于通道信息融合。
class AttentionModule(BaseModule):
def __init__(self, dim):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
self.conv3 = nn.Conv2d(dim, dim, 1)
def forward(self, x):
u = x.clone()
attn = self.conv0(x)
attn_0 = self.conv0_1(attn)
attn_0 = self.conv0_2(attn_0)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn = attn + attn_0 + attn_1 + attn_2
attn = self.conv3(attn)
return attn * u
同时,通过对比表8中有无MSCA模块的精度变化(无该模块miou会掉一个点);再对比表9中与Segformer的对比,SegNeXT的miou基本上要强3个点以上<没有MSCA模块,SegNeXT也很强劲
>,可见MSCA模块并不是SegNeXT优秀的核心原因。
3.3 MSCAN内部模块
StemConv
用于对输入数据的规整化,该操作很少被纳入backbone的范围(如resnet存在stem模块、hrdnet也存在stem模块)。MSCAN中Stem的proj平平无奇,而对输出数据的reshape操作却与众不同的进行的flatten操作,其将W与H进行了压扁操作。
class StemConv(BaseModule):
def __init__(self, in_channels, out_channels, norm_cfg=dict(type='SyncBN', requires_grad=True)):
super(StemConv, self).__init__()
self.proj = nn.Sequential(
nn.Conv2d(in_channels, out_channels // 2,
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels // 2)[1],
nn.GELU(),
nn.Conv2d(out_channels // 2, out_channels,
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
build_norm_layer(norm_cfg, out_channels)[1],
)
def forward(self, x):#假如输入x: batchsize,channel,width,heigh
x = self.proj(x)
_, _, H, W = x.size()
x = x.flatten(2).transpose(1, 2)#x.flatten(2) 表示将第2个维度后的数据都进行压扁
return x, H, W #此时输出x: batchsize,width*heigh,channel
Mlp
这是一个由conv1x1所构成的的全连接层,其中的DWConv是一个普通的分组卷积(kernel size 3x3)。
该module在block中的示意如下
class Mlp(BaseModule):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
SpatialAttention
该模块没有在论文中细提,其为MSCAN中的一个attention,与图2中模块的对应关系如下,图中红色圈的部分既为SpatialAttention。该代码是正常的卷积模组。
class SpatialAttention(BaseModule):
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
self.proj_1 = nn.Conv2d(d_model, d_model, 1)
self.activation = nn.GELU()
self.spatial_gating_unit = AttentionModule(d_model)
self.proj_2 = nn.Conv2d(d_model, d_model, 1)
def forward(self, x):
shorcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
x = x + shorcut
return x
Block实现
Block也就是上图MSCAN图a所示意的部分,与示意图不同的是Attention与FFN的输出在与跳跃连接作加时都进行了droppath操作。且,其中的FFN就是本节中所描述的MLP模块。该代码是正常的卷积模组。
class Block(BaseModule):
def __init__(self,
dim,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.attn = SpatialAttention(dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
layer_scale_init_value = 1e-2
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.permute(0, 2, 1).view(B, C, H, W)
x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.attn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.mlp(self.norm2(x)))
x = x.view(B, C, N).permute(0, 2, 1)
return x
OverlapPatchEmbed
该模块与transformer中的PatchEmbed类似,是利用kernel_size=patch_size的conx对局部进行特征编码,可以看到其输出x的shape为B,C,W/4*H/4。该模块既是用作下采样。
class OverlapPatchEmbed(BaseModule):
""" Image to Patch Embedding
"""
def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=dict(type='SyncBN', requires_grad=True)):
super().__init__()
patch_size = to_2tuple(patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = self.norm(x)
x = x.flatten(2).transpose(1, 2)
return x, H, W
3.4 backbone的具体实现
backbone完整实现代码
完整实现代码如下,其中的mlp_ratios=[4, 4, 4, 4],让mlp层中的参数剧增,(512x4) x (512x4)的全连接,参数量几乎爆表。这预计会给SegNeXT在小数据集上的使用带来极大的困难。
@ BACKBONES.register_module()
class MSCAN(BaseModule):
def __init__(self,
in_chans=3,
embed_dims=[64, 128, 256, 512],
mlp_ratios=[4, 4, 4, 4],
drop_rate=0.,
drop_path_rate=0.,
depths=[3, 4, 6, 3],
num_stages=4,
norm_cfg=dict(type='SyncBN', requires_grad=True),
pretrained=None,
init_cfg=None):
super(MSCAN, self).__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.num_stages = num_stages
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
sum(depths))] # stochastic depth decay rule
cur = 0
for i in range(num_stages):
if i == 0:
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
else:
patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_chans=in_chans if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([Block(dim=embed_dims[i], mlp_ratio=mlp_ratios[i],
drop=drop_rate, drop_path=dpr[cur + j],
norm_cfg=norm_cfg)
for j in range(depths[i])])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f"patch_embed{i + 1}", patch_embed)
setattr(self, f"block{i + 1}", block)
setattr(self, f"norm{i + 1}", norm)
def init_weights(self):
print('init cfg', self.init_cfg)
if self.init_cfg is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
else:
super(MSCAN, self).init_weights()
def forward(self, x):
B = x.shape[0]
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}")
block = getattr(self, f"block{i + 1}")
norm = getattr(self, f"norm{i + 1}")
x, H, W = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
stages分析
backbone构建stages的代码如下,可以看到stage由patch_embed + n个block + LayerNorm所实现(patch_embed为StemConv或OverlapPatchEmbed)。针对于不同的stage,embed_dims,n_block都是有所不同的。其中,第一个stage是由StemConv所构成的。
for i in range(num_stages):
if i == 0:
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
else:
patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
in_chans=in_chans if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
norm_cfg=norm_cfg)
block = nn.ModuleList([Block(dim=embed_dims[i], mlp_ratio=mlp_ratios[i],
drop=drop_rate, drop_path=dpr[cur + j],
norm_cfg=norm_cfg)
for j in range(depths[i])])
norm = nn.LayerNorm(embed_dims[i])
cur += depths[i]
setattr(self, f"patch_embed{i + 1}", patch_embed)
setattr(self, f"block{i + 1}", block)
setattr(self, f"norm{i + 1}", norm)
forword分析
backbone中的forword代码如下,可以看到数据的shape是在一直变化的。SegNeXT虽然没有使用transformer,但是却充分利用了PatchEmbed的特性。
def forward(self, x):
B = x.shape[0]
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}")
block = getattr(self, f"block{i + 1}")
norm = getattr(self, f"norm{i + 1}")
x, H, W = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
4、实验效果
4.1 ImageNet对比
ImageNet是一个著名的图像分类数据集,包含了1000种分类数据。SegNeXT在ImageNet进行预训练,以用于后续的语义分割。分类预训练效果如下:
4.2 iSAID对比
iSAID [76]是一个大规模的航空图像分割基准,它包括15个前景类和一个背景类。它的训练、验证和测试集分别涉及1411,458,937张图像。
4.3 Pascal VOC对比
Pascal VOC [18]涉及20个前景类和一个背景类。经过增强后,它分别有10582/1449/1456张图像用于训练、验证和测试。Pascal上下文[58]包含59个前景类和一个背景类。训练集和验证集分别包含4996张图像和5104张图像。
4.4 其他数据集
ADE20K [98]是一个具有挑战性的数据集,它包含150个语义类。它由训练、验证和测试集中的20210/3000/3352张图像组成。
Cityscapes[13]主要关注城市场景,包含5000张高分辨率图像,共19个类别。分别有2975/500/1525张图像用于训练、验证和测试。
COCO-Stuff [3]也是一个具有挑战性的基准测试,它总共包含172个语义类别和164k张图像。
4.5 实时方法对比
除了最先进的性能外,我们的方法也适用于实时部署。即使没有任何特定的软件或硬件加速,SegNeXt-T在处理大小为768×1,536的图像时,也可以使用一个3090 RTX的GPU来实现每秒25帧(FPS)。如标签页中所示。11,我们的方法在Cityscapes测试集上的实时分割设置了新的最先进的结果。
4.6 实验细节
基于timm库和mmseg框架机械能图像疯了和分割,所使用的模型均在IamgeNet-1k数据集上进行预训练。使用acc作为分类指标,使用mIoU作为分割指标。 所有模型均在8台RTX 3090显卡上训练。
分割实验中,采用了一些常见的数据增强方法,包括随机水平翻转、随机尺度化(从0.5到2)和随机裁剪。为Cityscapes数据集的批处理大小设置为8,为所有其他数据集设置为16。
应用AdamW [54]来训练模型。
初始学习率设置为0.00006,并采用多阶段学习率衰减策略(poly-learning)。
对ADE20K、城市景观和iSAID数据集训练模型160K迭代,对COCO-Stuff、Pascal VOC和Pascal VOC数据集训练模型80K迭代。
在测试过程中,同时使用了单尺度(SS)和多尺度(MS)翻转测试策略来进行公平的比较。
5、 消融实验
5.1 MSCA模块设计
在ImageNet和ADE20K数据集上进行了MSCA设计的消融研究。K×K分支包含一个深度的1×K卷积和一个K×1深度卷积。1×1 conv是指通道混合操作。注意是指元素的点乘操作,它使网络获得了自适应能力。结果如表6.我们可以发现,每一个部分都有助于最终的性能。
MSCA在分割中的重要性对比,我们将MSCA中的多个分支替换为具有大核的单一卷积。如表3和表8,我们可以观察到,使用MSCA模块的网络性能更为强劲,这说明聚合多尺度特征在编码器中是语义分割的关键。
5.2 全局上下文解码
解码器在整合分割模型中多尺度特征的全局上下文中起着重要的作用。在这里,我们研究了不同的全局上下文模块对解码器的影响。正如之前的大多数工作[75,19]所示,基于注意力的解码器对cnn中金字塔结构[94,5]具有更好的性能,因此我们只使用基于注意力的解码器来显示结果。具体来说,我们展示了4种不同类型的基于注意的解码器的结果,包括O(n 2)复杂度的非局部(NL)注意[75]和CCNet [34]、EMANet [40]和O (n)复杂度的HamNet [21]。如标签页中所示。Hamburger在复杂性和性能之间取得了最好的权衡。因此,我们在解码器中使用Hamburger[21]。
5.3 解码器结构
与图像分类不同,分割模型需要高分辨率的输出。我们采用了三种不同的解码器设计进行分割,如图3所示。相应的结果列在表7中,我们可以看到,SegNeXt ©取得了最好的性能,计算成本也很低。
6、特性分析
6.1 编码器分析
通过对SegNeXT代码的具体分析,可以发现SegNeXT的强劲,其实质在于编码器的强大(具体包括,将PatchEmbed引入传统卷积、将MLP引入传统卷积、提出MSCAN模组)其所提出的MSCAN在作为backbone时在限定参数量下就已经比一众transformer要强劲了。
6.2 解码器分析
SegNeXT将解码器简单的分成了3类,具体图下所示,并使用3类解码器进行了对比实验。其中可以发现其所使用的解码器并不具备压倒性优势,只要进行多尺度特征的输入,均能取得不错的效果
6.3 局限性分析
我们的模型也有其局限性,例如,将该方法扩展到具有100M+参数的大规模模型,以及在其他视觉或NLP任务上的性能并不具备先进性。
6.4 实用性对比
SegNeXT4个模型在ADE20K上的详细性能如下,可以看到其在内存上的消耗是巨大的。博主使用MSCAN-S简单进行训练测试,发现加载模型后显存消耗650MiB,对单个512x512 的图像进行forword后显存消耗为2315MiB,对512x512的图像进行训练,batchsize为8时显存为1200MiB。这种规模的内存消耗对于工程实践而言简直是灾难性的,从内存消耗与网络结构上看,应该是全连接所导致的。随后通过分析代码,发现就是MLP层中dwconv所导致的,其所包含的全连接有[64, 128, 256, 512]x4 ,其中最大的全连接层为2048*2048,简直恐怖如斯,强大是由代价的。
(mlp): Mlp(
(fc1): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
(dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)
(act): GELU()
(fc2): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
(drop): Dropout(p=0.0, inplace=False)
)
Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
---|---|---|---|---|---|---|---|---|---|
SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | config | model | log |
SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | config | model | log |
SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | config | model | log |
SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | config | model | log |
segformer在ADE20K上的性能如下,可以看出相比于segnext,segformer在训练于部署还是占据很强的优势的(segnext仅是在flops和推理时间上占据优势)
Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
---|---|---|---|---|---|---|---|---|---|
Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 38.17 | 37.85 | 38.97 | config | model | log |
Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 37.80 | 42.13 | 43.74 | config | model | log |
Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 26.80 | 46.80 | 48.12 | config | model | log |
Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 19.19 | 48.25 | 49.58 | config | model | log |
Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 14.54 | 49.09 | 50.72 | config | model | log |
Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | config | model | log |
Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 10.60 | 50.19 | 51.41 | config | model | log |