第一部分 概述
第二部分 SCNN网络
第三部分 SCNN网络关键代码分析
第一部分 概述
车道线检测是智能汽车辅助驾驶系统环境感知模块中非常重要的一个功能. 从基于传统方法的计算机视觉检测,到近年来逐渐成熟的基于深度学习的车道线检测, 各种检测思想不断被提出,.在上一篇博文中介绍了基于LaneNet_H-Net的车道线检测,但实际应用效果并不太好, 在车道线被其他车辆或障碍物大部分遮挡时,LaneNet并不能出色完成车道线检测任务. 因为现有的CNN网络随具有强大的特征提取能力,但对方向上的空间关系探索能力有限, 比如:车道线,电线杆这种细长的物体.
针对车道线这种特殊的物体检测,Spatial CNN(简称SCNN)被提出,使感知系统的车道线检测达到一个完美的阶段. 以及后来 百度Apollo 6.0 中车道线检测使用 Dark SCNN,在切片方式上有差别,也可理解为是SCNN的一个改进版.
SCNN对于长距离连续形状的物体, 及有着极强空间关系但因为遮挡二外观不连续的物体,例如车道线,具有优秀的检测效果.
第二部分 SCNN网络
熟悉CNN的朋友都知道,CNN的网络结构一般都是N个卷积层的堆叠, 然后一层(浅层)的输出作为下一层(深一层)输入,这样层接层(layer-by-layer)连接在一次, 复杂一点一些像FPN或彩虹网络那样自上而下或自下而上的深浅层特征融合(通道拼接,逐元素相乘,逐元素相加). 但都属于深浅层特征网络在垂直方向的特征(信息)提取或融合, 而在物体空间上或网络层水平(图像行,列)方向上, 各特征之间并没有充分的信息交流. 针对车道线这种在空间上呈现细长的, 需要强先验形状的物体, 空间上的关系显得格外重要.
一张来自网络或原作者的SCNN网络结果图,如下:
这张SCNN结构图基本已经表达了SCNN的基于特征图中片连片(slice-by-slice)卷积形式. 在进一步理解片连片之间,我用一个类似的例子,来帮助自己的理解:
假如某地举行一场徒步越野挑战赛(不分先后,到终点即可),有100人参与, 因为路途太长且沿途艰辛,体能原因大部分都会中途无奈退赛(相当于CNN中逐层丢失的弱语义特征),只有少部分人能到达终点(相当于图像中高语义特征). 二者100人中有10个人来自通过一个俱乐部且代表俱乐部参与比赛,互现认识且参赛图中可能会相互帮助,共同达到终点, 假设其他90人互不认识, 这10人之间的认识关系(强先验关系)可以理解为图像中一条细长的物体(强先验形状).
如果在按照标准的CNN网络规则, 所有人之间不能交流信息,不能相互协助,只能依靠个人体质往前冲, 强者胜出(强语义特征),弱者淘汰. 那么最后肯定只能剩下体质最强的参与者(真正意义的特征提取).
而SCNN网络,修改了规则, 参数者之间可以相互交流信息,可以相互帮助(相当于片连片的信息交流). 那么不管其他参与者结果如何, 来之同一个俱乐部的这10个人,可能会相互交流和帮助,最终全部达到终点(相当于挑战成功者之中仍保留了出发前的强先验认识关系).
例子可能不太恰当,当重点是强调了图像中特征之间这种横向的信息交流.
继续上面SCNN网络结构图, 分几步来分析:
一. 特征提取:
输入待检测图片,利用CNN网络层提取特征,这里可以是vgg等成熟的特征提取网络. 提取的特征一般是一个[B,C,H,W]形状的张量, 对于其中一张图片来说是一个[C,H,W]形状的张量. 如下图:
二. W和H维度方向上切片:
让图像特征中中,行和列之间传递信息. 即在[C,H,W]中的H维度方向上进行切片,分H片[C,W],片与片之间采用向下和向上两种信息交流方式; 在W维度方向上进行切片,分W片[C,H],片与片之间采用向左和向右两种信息交流方式. 至于C维方向,本身是在标准CNN网络上的特征提取,不需要做切片信息交流. 这样就有了向上,向下,先左,向右四种切片和信息交流方式;
三. 片与片信息交流过程
以H维度方向上进行切片为例,其他都类同. 首先取出特征中第0片[C,W]_0,作为新的[C,W]_0', 将[C,W]_0 经一个卷积卷积核为(1,ω)卷积后(也可以再通过一个ReLU激活函数),与原始第1个切片[C,W]_1相加后作为新的第1片[C,W]_1'. 依次有:
[C,W]_(i+1)' = [C,W]_i + ReLU( Cov( [C,W]_(i)' ) ) i 取值区间[0, H)
最终,将新生成的特征[C,H,W]' 作为新的特征,进行下一步处理.
四. 损失计算
经过步骤三的片与片信息交流后,输出一个同样维度[B,C,H,W]特征张量. 再经一个卷积和Softmax后,生成一个[B,n,h,w]维张量, 这里的n对应需要检测的车道线个数.
进一步,SCNN分两个分支计算损失: 实例分割预测(seg_pred) 和 是否存在车道线预测(exist_pred).
(1). 实例分割预测(seg_pred)
[B,n,h,w]维特征经插值上采样后和实例分割真值做交叉熵损失计算(CrossEntropyLoss),获取实例分割预测损失: seg_loss.
(2). 是否存在车道线预测(exist_pred)
[B,n,h,w]维特征张量经MaxPool后,经全连接层后,生成一个[m,n], 这里的n对应需要检测的车道线个数, 表示各个车道线存在与否的得分.
进一步,将输出的exist_pred[m,n]与是否有车道线真值做交叉熵损失计算, 得到: exist_loss.
第三部分 SCNN网络关键代码分析
这里仅分析SCNN网络的片与片信息交流部分代码. 在经上述第一步特征提取网络获取特征张量[B,C,H,W]后,在H维方向上分别做向上和向下切片和片间交流,然后,在此基础上进一步在W维方向上分别做向左和向右切片和片间交流. 一种代码实现如下:
def message_passing_forward(self, x):
#True:表在H维的垂直方向, False: 表在W维的水平方向
Vertical = [True, True, False, False]
#True or False:表在当前切片方向,是正向(向下/右)还是反向(向上/向左)处理
Reverse = [False, True, False, True]
for ms_conv, v, r in zip(self.message_passing, Vertical, Reverse):
x = self.message_passing_once(x, ms_conv, v, r)
return x
def message_passing_once(self, x, conv, vertical=True, reverse=False):
nB, C, H, W = x.shape
if vertical:
slices = [x[:, :, i:(i + 1), :] for i in range(H)]
dim = 2
else:
slices = [x[:, :, :, i:(i + 1)] for i in range(W)]
dim = 3
if reverse:
slices = slices[::-1]
out = [slices[0]]
for i in range(1, len(slices)):
out.append(slices[i] + F.relu(conv(out[i - 1])))
if reverse:
out = out[::-1]
return torch.cat(out, dim=dim)