基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现

之前读过这篇论文,导师说要复现,这里记录一下。废话不多说,再重读一下论文。
注:非一字一句翻译。个人理解,一定偏颇。

基于深度强化学习的车道检测和定位

官方源码下载:https://github.com/tuzixini/DQLL
论文原文:https://www.sciencedirect.com/science/article/pii/S0925231220310833
如需引用:
@article{zhao2020deep,
title={Deep reinforcement learning based lane detection and localization},
author={Zhao, Zhiyuan and Wang, Qi and Li, Xuelong},
journal={Neurocomputing},
volume={413},
number={6},
pages={328-338},
doi={10.1016/j.neucom.2020.06.094},
year={2020}
}

摘要

基于深度学习的车道检测方法只检测带有粗略边框的车道线,而忽略了特定曲线车道的形状。针对上述问题,本文将深度强化学习引入粗车道检测模型中,以实现精确的车道检测和定位。该模型由**边界盒探测器(bounding box detector)和地标点定位器(landmark point localizer)**两个阶段组成。边界盒级卷积神经网络车道检测器以边界盒的形式输出车道的初始位置。然后,基于强化学习的深度Q-Learning定位器(Deep Q-Learning Localizer,DQLL)将车道作为一组地标进行精确定位,以更好地表征曲线车道。构造并发布了一个像素级车道检测数据集NWPU车道数据集。它包含了各种真实的交通场景和精确的车道线遮罩。该方法在发布数据集和存储数据集上都取得了较好的性能。

1 引言

避免事故发生和引导车辆沿着适当的车道行驶是辅助系统的两项基本任务,实现这两个目标的几个技术手段:车道检测,道路检测,前方车辆碰撞预警,交通标志检测,交通拥堵检测,道路标记检测。车道检测在上述任务和其他高级驾驶辅助目标中有着不可替代的作用,如可行驶区域检测和自动泊车。图1显示了表示车道的不同的方法,包括直线、边框、地标和像素掩码。
图1
在深度卷积神经网络(Deep Convolutional Neural Networks, DCNN)广泛研究和应用之前,很多工作都是使用低级特征提取器来检测车道线,并使用多条直线来表示车道。直线在直线车道上很好,但在曲线车道上就不行了。为了解决曲线车道的表示问题,在车道检测中引入边界盒和像素级掩码。但是,边界盒的精度不够高,像素级掩模的预测需要复杂的计算。
为了解决上述问题,我们提出了一种基于深度强化学习的车道检测和定位网络。它由深度卷积巷边界盒检测器(deep convolutional lane bounding box detector)和深度q学习定位器(Deep Q-Learning localizer)组成。所提网络的结构示意图如图2所示。
图2
它是一个两阶段的顺序处理架构。具体来说,第一阶段是一个改进的Faster R-CNN[27],它以包围盒的形式检测道。第二阶段为轻量化深度Q-learning地标定位器,由五层卷积层和三层全连接层组成。
在检测阶段得到边界盒后,初始地标沿边界盒对角线均匀分布。然后车道定位任务变成了一个点移动博弈游戏。车道定位器在游戏中扮演Agent的角色。agent需要做的是根据当前环境状态将地标向特定方向移动当前环境状态包括当前点位置、动作历史向量和已编码的图像特征。最后,当agent决定不再移动路标点时,将所有路标点的位置输出为车道的位置。
为了验证所提方法的有效性,我们建立了一个名为NWPU Lanes dataset的像素级车道数据集,该数据集包含1964个交通场景图像,并带有标记良好的像素级车道遮罩。
contributions:

  • 定义了一种新的车道检测和定位表示方法,达到了精度和计算量的平衡。
  • 深度Q-Learning车道定位器(DQLL)将车路定位为一组地标,对曲线车道进行了较好的表征。
  • 构建一个像素级车道数据集NWPU车道数据集,其中包含精心标注的城市图像,有助于发展交通场景的理解。

2 相关工作

2.1 传统车道线检测方法

  • 为车道线构造易于识别的特征,根据其相同的特征手工设计特征表示。常用的特征提取器如Hough变换[21]和Dark-Light-Dark
    (DLD)[22]在简单的条件下是有效的,但是在复杂的场景下性能会迅速下降。它们对噪音的敏感导致了这个问题。
  • 逆透视映射(IPM)[28]将原始图像转换为鸟瞰图。然后使用上面的特性提取器在这个视图下生成特性。视图转换有助于减少冗余信息并增强目标表示。但在非常复杂的情况下,其效果明显下降。其根本原因是DLD、霍夫变换等低级提取器提取的特征不够强大。

2.2 基于深度学习的车道线检测

DCNN可以从输入图像中生成具有足够高水平语义信息的特征。此外,它的自动拟合特性节省了大量的特征设计工作。

2.3 强化学习

Mnih等[39]将Q-Learning与deep Q-Learning Network (DQN)中的深度学习方法相结合,即使用神经网络代替Q-table。

3 方法论

3.1 概述

本方法由检测和定位部分组成。检测部分的目的是获得车道的初步边界盒位置。为了配合下一阶段的定位过程,我们仔细考虑了车道的特点,通过观察如图3所示的各种车道,我们总结如下:

  • 边界框框出的车道线总是靠近矩形对角线
  • 左上到右下:视野的左边;左下到右上:视野的右边 (我总感觉这里写反了的样子。。)*
  • 边界框的对角线大致可以用来表示直线的位置。对于弯曲车道,由于车道形状的巨大差异,它失败了。
    车道所经过的对角线将是确定路标点初始位置的关键因素。

图3

3.2 车道线检测

首先在车道数据集上对改进后的公共目标检测器进行再训练,并将其用于获得车道的初始位置。从技术上讲,几乎所有典型的对象检测器,如[40-43]都可以在这里使用。**第一阶段我们采用Faster R-CNN作为基线车道边界盒检测器。**通过检测车道坡度,将车道划分为不同类型,并将车道类型与3.1节讨论的内容统一起来。检测阶段完整工作流程如图4所示。
图4
Faster R-CNN使用CNN完成proposal生成、回归和分类。一个输入图像在网络中只传播一次,提高了网络的效率。
用VGG作为CNN的骨架。它从输入的三帧RGB图像中提取卷积特征映射。然后将整个图像特征向量和图像信息发送到RPN,生成区域建议。ROI pooling层有助于将区域建议的特征向量强制为固定大小。建议回归网络和建议分类网络分别使用多个全连通层来得到边界框偏差和分类概率。网络的详细结构如图4所示,其中Conv表示卷积层,Dense表示完全连通层。
车道检测阶段的最终输出为输入图像内所有车道的边界框位置和车道类型。

3.3 车道线定位

我们使用五个地标点来准确定位车道。地标在边界框内统一初始化。这样,定位阶段就变成了一个点移动博弈游戏,目标是将所有的地标移动到正确的位置。应用一种基于强化学习的深度Q-learning车道定位器来进行游戏。与边界框相比,地标有效地提高了曲面车道的表示能力,提供了更精确的位置信息。

3.3.1 游戏定义

如图3所示,经检测阶段的每个包围盒与盒子在水平方向上通过5条截止线分割成6个相等的区域。车道线所沿的对角线与这五条分割线在几个点相交。这些点被用作地标点的初始位置。
我们尝试通过深度强化学习方法来解决点定位博弈问题。这里使用的学习策略是Q-Learning[38]方法。在原来的Q表中,它对每个不同的环境状态进行了重新编码,哪个行为选择会导致最高的回报。初始Q表给出随机的行动决策,它根据以下公式随训练过程更新:
在这里插入图片描述
(关于公式的理解和别的地方一样,这里不再赘述。)
除了Q表之外,环境状态、行动选项和奖励功能共同构成了深度Q学习的过程。下面的小节将详细介绍这三个关键组件。

3.3.2 环境状态

环境状态包含了影响行动决策结果的因素。对于这个移动点游戏,当前选择的地标点的位置信息,以及图像块都有助于找到正确的位置。我们还考虑了之前已经做过的动作,我们称之为动作历史向量。(就是引言里说过的三部分组成:当前点位置、动作历史向量和已编码的图像特征)
在这里插入图片描述
S是当下环境状态,等式右侧第一项是已编码的图像特征(Ib是边界盒框起来的部分),第二项是当前点位置,第三项是动作历史向量。中间的符号表示concatenate操作。
图5表示了环境状态的组成。
图5

3.3.3 动作空间

地标点的纵坐标是一个固定的值,所以它只能水平移动。我们人为地定义了三种可选的操作类型。

  • delete action:agent决定删除当前点或采取其他操作。偏离范围或距离实际车道位置太远的点可能被删除。
  • moving action:对正常范围下的地标点,agent将点移向正确的方向,这些点沿水平线有两个移动方向,因此移动动作包含向左或向右的运动。
  • terminal action:当点与期望位置足够接近时,agent必须判断当前位置是否为最终位置。终端动作决定截断点移动过程或进入下一个动作选择。

所有的动作选择以及相应的实际像素级点移动如表1所示。其中x表示当前地标点的位置。
表1

3.3.4 奖励函数

我们根据行动选择所导致的结果将其分为三种类型。

  • Invalid Action Choices:动作a将地标点移出了适当的图像范围,删除了应该保留的点或保留了应该删除的点。

在这里插入图片描述

  • Regular Action Choices:如果这个动作选择不是前面提到的无效的,而是一个移动的动作,我们称这个选择为常规的动作选择。我们定义当前点位置之间的距离和环境状态下点的真实位置为d(s)。
    d(s’)新距离
    在这里插入图片描述

点离真实距离更近得一分,否则扣一分。

  • Terminal Action Choices:agent做出终止点移动过程的决策。当当前点和地面真实位置足够接近时,agent通过terminal action choices获得正分数。否则,如果代理在不合适的时间停止移动过程,则会得到负的分数。

在这里插入图片描述

3.3.5 概述

完整的车道定位工作流程如图6所示。
图6
首先将检测到的所有边界框从完整图像中截断,然后将其调整为统一的尺寸**[100, 100, 3]**,然后再送到Deep Q-Learning Localizer(DQLL)网络中。定位器根据边界盒的类型初始化5个地标点,即5个地标的水平和垂直坐标均匀分布在0 ~ 100之间。此外,这五个点对于不同的车道可能沿着不同的斜线。初始化完成后,分别对5个地标点进行定位。
图6显示了右侧的决策网络结构。具体的网络架构如表2所示。
在这里插入图片描述
“Conv1: (k(3,3),c(3->48),s(1),NoPadding)”:意味着卷积层名为Conv1使用48个卷积核的大小3x3x3与strides=2 (这里我认为该是1)和nopadding。
‘‘FC: 5393 -> 512” :全连接层,输入尺寸为5393,输出尺寸为512。
动作历史向量的长度影响第一完全连接层的输入端形状。这里我们使用四个过去的动作来形成动作历史向量。因此,第一个完全连接层的输入大小为1 x 21 x 256 + 1 + 4 x 4 = 5393。
1 x 21 x 256:特征编码器的输出
1 + 4 x 4:四个步长的动作历史向量
(我认为1是当前点位置)

损失函数:MSE
在这里插入图片描述

4 数据集

本文提出具有像素级别标签的NWPU数据集,它来自于真实驾驶场景下录制的视频。
由于人工采集和标注真实场景数据比较困难,我们从数据集[47]中选择一些虚拟数据来辅助我们的训练过程。这些图像和标签掩码都是由带有精确注释的软件生成的合成数据。

4.1 NWPU车道线数据集

汽车视频数据记录器收集了13个真实驾驶场景的视频。其中12个片段是3分钟长,剩下的一个是1分38秒长。经过每秒1帧的采样,总共得到2258张初始图像。通过对这些图片的观察,我们发现在实际驾驶过程中,由于车辆停车、拥堵、遮挡等问题,仍然有很多图片无法用于初步样本。所以我们又手动删除了一张图片,删除后保留了1964张图片。接下来,我们使用自己开发的打标工具进行准确的像素级线打标。标记完成后,可以通过像素级标记生成包围盒。
图7左侧的部分(A)展示了来自NWPU lane数据集的驾驶场景及其相应的掩码。最终得到的数据分为训练集和测试集,其中测试集占20%。数据分布如表3所示。
图7

4.2 合成数据集

手工采集的数据在标注工作中会不可避免的出现误差。在车道线定位的过程中,输入图像是一个小尺寸的图像,只包含从原始图像中截断的车道线区域。因此,在定位过程中,原始图像级的误差可能会放大。与人工标注相比,软件生成的虚拟数据具有完全准确的像素级标注。因此,我们不仅使用自己构建的真实数据集,还从其他数据集中选择适当场景下的虚拟数据进行训练和测试。SYNTHIA数据集包含大量生成的数据,这些数据是在不同的场景、时间、季节和天气设置中构建的。我们手动选择一些接近NWPU lane数据集的场景。这些虚拟场景数据的加入有效地促进了DQLL的学习过程。我们还将这些数据拆分为列车集和测试集,拆分规则与自构建数据集一致,详细信息如表3所示。
表3
(表格两行好像搞反了)

5 实验和讨论

5.1 评价指标

定义了两个评价指标,这两个指标的定义与第3节车道线定位的具体实现密切相关。这两个指标都是在一个完整的测试过程中定义的。

  • Hit Rate α:反映定位精度。‘‘Hit Point”:最终终止点与ground truth位置点之间的距离小于5像素。

在这里插入图片描述
分子:在特定测试期间命中的总点数
分母N:所有地标点的个数,也等于5倍的边界盒数
0<α<1。愈高愈好

  • Average Step:反应定位速度。

在这里插入图片描述
S:本测试期内所有地标点的行动步骤总数,
N:同上式。

5.2 实验设置

NVIDIA GTX 1080Ti GPU
Inter Core i7-6800K@3.4 GHz CPU
Ubuntu 14.10
TensorFlow [48] or PyTorch
两个阶段依赖于相当独立的实验设置。
边界盒检测阶段:在基于像素级lane数据集的基础上,采用简单的连通分量检测算法生成检测阶段所需的边界盒。Faster R-CNN在训练过程中,批大小为1,边界盒分类的批大小为300。学习率和权值衰减分别设置为0.001和0.0005。
定位阶段:在DQLL的训练过程中,需要与检测阶段不同的地面真实数据。我们进一步对第一步的边界框数据进行处理。首先,对每个矩形盒分别从原始图像中截断;然后,每个盒子被5个分割器分成6个等面积的水平矩形。分隔线与车道线相交于一小段。**最后,将短截面的中点作为对应地标点的ground truth。**直到此时,训练DQLL的数据才准备好。定位阶段的学习率设置为0.0001,批量大小为1024进行训练。

5.3 实验结果

验证DQLL的效果,在NWPU 和 TuSimple 数据集分别进行验证。

  • 于NWPU数据集:表4给出了DQLL与检测、分割等车道检测方法相结合的测试结果。如我们所料,基于分割的初始精度高于基于检测的初始精度。

表4

  • 于TuSimple数据集:3626张用于训练,2782张用于测试。我们将标记的数据转换为DQLL训练和测试所需要的形式。表5显示了不同方法的结果。

表5
TuSimple Lane数据集的总命中率高于NWPU Lane数据集。这意味着前者有更多的直线。实验结果表明,平均步数随着第一阶段初始检测效果的提高而减少。因为更可靠的检测结果可以在初始化地标点时提供更好的指导。
表6
从表4-6所示的实验结果可以看出,无论使用何种实验设置,DQLL都可以在四步之内完成一个地标点的定位过程,对于NWPU Lanes数据集,平均只需要三步,对于TuSimple Lane数据集,平均只需要不到两步。显然,动作历史向量的长度会影响命中率和平均步长。选择合适的长度变得很重要,我们在5.4小节中进行具体分析。

  • DQLL 可视化
    图8展示了DQLL定位过程的可视化。所有子图由五行组成。每一行对应于一个地标点的定位过程。蓝色圆圈是地面真相地标的置信区间,绿色的点是当前移动的地标点。因此,对于每一个子图,左上角是初始点位置,右下角是定位过程的最终输出。子图(A)和©展示了正常曲线车道线的细节定位过程。与初始地标位置相比,DQLL有效地移动了地标以更好地拟合曲线。子图(B)是一条直线,但偏离了边界盒的中心,因此去掉了两个地标点,只剩下三个点。第四个子图给出了一个反面的例子,五个标志性点的初始位置在预期范围内,但是DQLL错误地将其移出了预期范围。
    图8a
    图8b

图8c
图8d
在这里插入图片描述

5.4 动作历史向量的影响

动作历史向量的长度设置为范围为[0, 10],步长为2。所有实验结果如表6所示。
与没有动作历史向量的DQLL相比,不管历史向量有多长,有了这个向量的DQLL都提高了定位精度,减少了平均需要的步长。这说明过去的动作选择在一定程度上有助于当前状态下的决策。然而,并不是越长越好。历史向量长度为4时,命中率最好;历史向量长度为6时,平均步长最好。这两个评估指标都倾向于随着动作历史向量长度的增加而从增加变为减少。因此,选择适当长度的历史步骤是至关重要的。

5.5 与监督学习方法对比

我们手工设计了几种深度监督学习(DSL)方法,这些模型具有与DQLL完全相同的网络结构。唯一的区别是这些DSL方法尝试直接回归拟合五个地标点的位置,并且它们的训练损失函数不同。在此,我们分别使用MSE、L1和 smooth L1损失函数来训练DSL模型。实验结果如表7所示。图9说明了所比较的DSL模型的网络架构。
表7
图9
比较结果表明,学习如何移动地标点到正确的位置比直接返回地标点的位置更有效。

6 结论

在未来,我们将尝试利用更多的先验知识来提高检测性能。

代码解读

$CODEROOT:放置此代码的路径。
$DATAROOT:放置数据集的路径。

准备

Python 3.x
Pytorch 1.x

确保你的代码目录如下:
在这里插入图片描述
在这里插入图片描述

下载图森数据集

我们需要的是来自“LANE DETECTION CHALLENGE”的数据。
在这里插入图片描述
确保你的**$DATAROOT**目录如下:
在这里插入图片描述
使用genMyData.pygenMeanImg.py生成新数据。

  1. 修改变量DATAROOT(文件getMyData.py , 210行)。到您使用的实际的$DATAPATH(这里是r"/opt/disk/zzy/dataset/TuSimple")。
DATAROOT = r"/opt/disk/zzy/dataset/TuSimple"
  1. 转到$CODEPATH并运行genMyData.py,等待完成。
cd $CODEPATH
python genMyData.py

genMyData.py

# coding=utf-8
# ----tuzixini@gmail.com----
# WIN10 Python3.6.6
# 用途: 处理TuSimpleLane 数据集
# 生成需要的数据
# genMyData.py

# 导入模块
import json
import os
import os.path as osp
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import shutil
from tqdm import tqdm

faileList = []

def doit(sroot, jlistpath, droot,namelist=[],DRL_list=[],DRLcount=0):
    os.makedirs(droot, exist_ok=True) # 创建目标目录
    os.makedirs(osp.join(droot, 'img'), exist_ok=True)
    os.makedirs(osp.join(droot, 'mask'), exist_ok=True)
    os.makedirs(osp.join(droot, 'mask_color'), exist_ok=True)
    os.makedirs(osp.join(droot, 'bbox'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL','ori'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL', 'resize'), exist_ok=True)
    jlist =[]
    with open(jlistpath, 'r') as f: # 把Json文件按行读入
        for line in f.readlines():
            jlist.append(json.loads(line)) # 放到jlist中(第一步共有358行)
    for ins in tqdm(jlist):
        faileIns = dict()
        imgpath = ins['raw_file']
        temp = imgpath.split('/')
        newname = temp[1] + temp[2] + temp[3][:-4]
        namelist.append(newname)
        spath = osp.join(sroot, imgpath) # 源目标文件路径
        # 计算
        try:
            temp = getInfo(spath, ins)
        except:
            faileIns['sroot'] = sroot
            faileIns['jlistpath'] = jlistpath
            faileIns['ins'] = ins
            faileList.append(faileIns) # 若程序运行不正常,faileList存储报错次数(根据笔者半天时间亲测:运行的第一个文件中第281行存在错误。https://blog.csdn.net/songyuc/article/details/109769131   存在一个边界框,对y=91时没有车道线,所以会报错。其他文件也存在一些报错,不过不是很多,不影响运行
            continue
        mask, mask_color, bbox, box, box_mask, gt, rebox, rebox_mask, regt = temp
        if len(box) > 5:
            faileIns['sroot'] = sroot
            faileIns['jlistpath'] = jlistpath
            faileIns['ins'] = ins
            faileIns['Reason']="cot>5"
            faileList.append(faileIns) # 若box(原检测出图片中车道线数量>5(存在误检,实际上不一定大于5,则记录他们源目录、json文件目录、边界框点集、原因)
            continue
        #  copyimg
        dpath = osp.join(droot, 'img', newname + '.jpg')
        shutil.copy(spath, dpath) # 将spath的文件复制到dpath。详见https://www.cnblogs.com/liuqi-beijing/p/6228561.html
        # mask
        dpath = osp.join(droot, 'mask', newname + '.png') 
        mask = Image.fromarray(mask.astype('uint8'))# 实现array到image的转换。详见https://blog.csdn.net/weixin_39450145/article/details/103874310
        mask.save(dpath) # 生成图片并保存(全是黑色)【有车道线地方为1,无车道线地方定义为0】
        # mask_color
        dpath = osp.join(droot, 'mask_color', newname + '.png')
        mask_color = Image.fromarray(mask_color.astype('uint8'))
        mask_color.save(dpath) # 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】
        # bbox
        dpath = osp.join(droot, 'bbox', newname + '.json')
        with open(dpath,'w') as f:
            json.dump(bbox,f) # 注:bbox是四个点的形式存在。有时可能会将一条车道线检测为两条,不过不影响后续定位。
        # box 裁剪出来的图片
        for i in range(len(box)):
            # 获取裁剪出来图片的名称
            DRLname = newname + '_' + str(DRLcount)
            DRL_list.append(DRLname)
            # box
            temp = box[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot,'DRL','ori',DRLname+'.png')
            temp.save(dpath) # 依次检测原图片中的每一条车道线
            # boxmask
            temp = box_mask[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'ori', DRLname+'_mask.png')
            temp.save(dpath) # 掩码,基本黑色(有车道线地方为1,无车道线地方定义为0)
            # boxmask_color
            boxmask = box_mask[i]
            temp = np.zeros(boxmask.shape)
            temp[boxmask == 1] = 255
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'ori', DRLname + '_mask_color.png')
            temp.save(dpath) # 掩码 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】
            # gt
            dpath = osp.join(droot,'DRL','ori',DRLname+'.json')
            with open(dpath, 'w') as f:
                json.dump(gt[i], f) # 真实车道线类别和坐标
            # rebox
            temp = rebox[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'.png')
            temp.save(dpath) # 生成缩放后的边界框图片(从原图像截取后再缩放) 100x100
            # reboxmask
            temp = rebox_mask[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'_mask.png')
            temp.save(dpath) # 掩码,基本黑色(有车道线地方为1,无车道线地方定义为0) 100x100
            # reboxmask_color
            boxmask = rebox_mask[i]
            temp = np.zeros(boxmask.shape)
            temp[boxmask == 1] = 255 # 将rebox_mask为1的像素点置为255
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize',DRLname + '_mask_color.png')
            temp.save(dpath) # 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】 100x100
            # regt
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'.json')
            with open(dpath, 'w') as f:
                json.dump(regt[i], f) # 经缩放后真实地标点坐标(re ground truth)
            DRLcount += 1
    return namelist,DRL_list,DRLcount # namelist:文件名称列表(图片数量),DRL_list:所有图片的车道线(可能重复),DRLcount:对DRL_list进行计数。


def getInfo(ipath,data): # 传入图片和json文件中的对应行
    img = Image.open(ipath)
    img = np.array(img)
    img_t = np.zeros(img.shape)
    mask = np.zeros((img.shape[0], img.shape[1]))
    mask_color = np.zeros((img.shape[0], img.shape[1]))
    gt_lanes_vis = [[(x, y) for (x, y) in zip(lane, data['h_samples'])if x >= 0] for lane in data['lanes']] # 存储每条车道的坐标信息

# 依每条车道的坐标画图线(如图7的黑白图片)
    for lane in gt_lanes_vis:
        cv2.polylines(img_t, np.int32(
            [lane]), isClosed=False, color=(0, 255, 0), thickness=5) # img_t存贮车道线坐标(拟合一条线)
    mask_color[img_t[:, :, 1] == 255] = 255 # 生成掩码(有车道线的地方置为255,没有则是0)
    mask[img_t[:, :, 1] == 255] = 1  # 生成掩码(有车道线的地方置为1,没有则是0)
    # 计算bbox
    temp = Image.fromarray(mask.astype('uint8')) # 把mask转换为图像格式
    temp = np.array(temp) # 图像 到 矩阵
    bbox = []
    box = []
    box_mask = []
    gt = []
    rebox = []
    rebox_mask = []
    regt = []
    temp, cons, hier = cv2.findContours(temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) # 需注意不同版本opencv返回值的不同。**https://www.cnblogs.com/guobin-/p/10842486.html**    博主在这里卡了一天,汇报又要挨骂了。。。OpenCV2和OpenCV4中是两个返回值,OpenCV3中有三个返回值
    for con in cons: # 对轮廓的点集进行遍历
        x, y, w, h = cv2.boundingRect(con) # 边界框
        # 截取原图
        box.append(img[y:y+h, x:x+w, :].astype('uint8'))
        # 截取mask
        temp = np.zeros(img.shape) # temp是掩码
        cv2.drawContours(temp, [con], 0, (0, 0, 255), -1) # 在temp上画上边界框
        ttemp = np.zeros((img.shape[0], img.shape[1])) 
        ttemp[temp[:, :, 2] == 255] = 1# temp有边界框地方置为1,其他地方为0
        box_mask.append(ttemp[y:y+h, x:x+w].astype('uint8')) # 截取边界框区域
        # 计算bbox [{'class':cl,'points',[x1,y1,x2,y2]},{}]
        temp = ttemp[y:y+h, x:x+w]
        [vx, vy, xx, yy] = cv2.fitLine(con, cv2.DIST_L2, 0, 0.01, 0.01) # https://blog.csdn.net/lovetaozibaby/article/details/99482973         前两维代表拟合出的直线的方向,后两位代表直线上的一点。(即通常说的点斜式直线)         (vx,vy)是与直线共线的标准化向量(x,y)是线上的一个点
        slope = -float(vy)/float(vx)
        if slope <= 0:
            # 左上 到右下 (右侧)
            cl = 0
        else:
            # 左下到右上(左侧)
            cl = 1
        ttemp = dict()
        ttemp['points'] = [int(x), int(y), int(x+w), int(y+h)]
        ttemp['class'] = cl
        bbox.append(ttemp) # bbox存贮边界盒
        # 计算gt[{'class':cl,'gt':[x1,x2,x3,x4,x5]},{}]
        ttemp = dict()
        ttemp['class'] = cl
        initY = []
        for i in range(5):
            initY.append(int((i+1)*(h/6))) # 把五个纵坐标找到
        initX = []
        for y in initY: # ground truth的定义见文章5.2节
            xx = temp[y, :]
            xx = np.where(xx == 1)
            x = int((np.max(xx)+np.min(xx))/2)
            initX.append(x)
        ttemp['gt'] = initX
        gt.append(ttemp) # gt存储真实五个点X信息
    # 生成resize的DRL材料**(我认为是把图片给缩放成100x100尺寸【rebox】,掩码也缩放【rebox_mask】,也把ground truth点缩放【regt】)**
    for i in range(len(box)):
        temp = box[i].copy() # box存贮经截取了的图片(一张图片中几条车道线几个box)
        temp = cv2.resize(temp, (100, 100))
        rebox.append(temp)

        temp = box_mask[i].copy()
        temp = cv2.resize(temp, (100, 100))
        rebox_mask.append(temp)
        # pdb.set_trace()

        ttemp = dict()
        ttemp['class'] = gt[i]['class']
        initY = [11, 31, 51, 71, 91]
        initX = []
        for y in initY:
            xx = temp[y, :]
            xx = np.where(xx == 1)
            x = int((np.max(xx)+np.min(xx))/2)
            initX.append(x)
        ttemp['gt'] = initX
        regt.append(ttemp)
    result = [mask, mask_color, bbox, box, box_mask, gt, rebox, rebox_mask, regt]
    return result # 最终的返回值:mask(黑色掩码【背景】),mask_color(白色掩码【车道线】),bbox(边界盒【四个点坐标+类别),box(经截取了的图片),box_mask(截取了的图片的掩码,黑色背景),gt(真实地标点坐标),rebox(经缩放的边界盒),rebox_mask(经缩放的黑色背景掩码),regt(经缩放的真实地标点坐标)

DATAROOT = r"/home/wqf/tusimple" # 你自己下载的tusimple数据集目录所在

# 测试数据集
sroot = 'train_set' 
jlistpath = r'train_set/label_data_0531.json'
droot = 'MyTuSimpleLane/train'
sroot = osp.join(DATAROOT, sroot) # 源目录
jlistpath = osp.join(DATAROOT, jlistpath) # json文件目录
droot = osp.join(DATAROOT,droot) # 目标目录
namelist, DRL_list,DRLcount = doit(sroot, jlistpath, droot, namelist=[], DRL_list
  • 7
    点赞
  • 81
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值