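[LaneNet] Reproducing the Lane Detection Code

This post reproduces LaneNet's lane detection results. It does not cover the theory behind the algorithm.

Straight to the point, here is the link:
Link: https://pan.baidu.com/s/1yxJNDdR1y4ixW62gDuDawQ
Extraction code: hcwi

There is plenty of material on the LaneNet algorithm online, including many GitHub repositories. Perhaps my search skills are limited, but after several days of tinkering I still could not reproduce the results. The main reason was that some files could not be found or downloaded. All the relevant files are now in the Baidu Netdisk share above.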

Windows
Python 3.5.2
See requirements_new.txt for the exact library versions
Last updated 2020.06.09

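1. Download the archive and extract it. Note that the New_Tusimple_Lanenet_Model_Weights weights under the model folder were added by me; some GitHub repositories and blog posts do not provide them. (I went to great lengths to get these weights, and they are now shared on Baidu Netdisk: New_Tusimple_Lanenet_Model_Weights, extraction code: s40b.)
2. Edit test_lanenet.py in the tools folder and add the project folders to sys.path; otherwise the script raises an error. Change the paths below to match your own machine: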

import sys 
sys.path.append('C:/Users/Lenovo/Desktop/lanenet-lane-detection-master')
sys.path.append('C:/Users/Lenovo/Desktop/lanenet-lane-detection-master/config')
sys.path.append('C:/Users/Lenovo/Desktop/lanenet-lane-detection-master/data_provider')
sys.path.append('C:/Users/Lenovo/Desktop/lanenet-lane-detection-master/lanenet_model')
sys.path.append('C:/Users/Lenovo/Desktop/lanenet-lane-detection-master/tools')
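
The hard-coded paths above have to be edited on every machine. As an alternative, here is a minimal sketch that derives them from the location of the script itself. It assumes the script sits in a folder directly under the project root (as tools/test_lanenet.py does); the helper names project_paths and add_to_sys_path are made up for this illustration and are not part of the LaneNet code.

```python
import sys
from pathlib import Path

# Sub-folders that the original snippet appends by hand.
SUBFOLDERS = ("config", "data_provider", "lanenet_model", "tools")


def project_paths(script_file):
    """Return the project root and its sub-folders as strings.

    Assumes the script lives one level below the root (for example in tools/).
    """
    root = Path(script_file).resolve().parent.parent
    return [str(root)] + [str(root / name) for name in SUBFOLDERS]


def add_to_sys_path(paths):
    """Append each path to sys.path once, keeping the given order."""
    for p in paths:
        if p not in sys.path:
            sys.path.append(p)


# In tools/test_lanenet.py this would be called as:
# add_to_sys_path(project_paths(__file__))
```

For a script placed in the project root itself, the root would be Path(__file__).resolve().parent rather than its parent.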
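3. From the lanenet-lane-detection-master folder, run:
python tools/test_lanenet.py --weights_path model/New_Tusimple_Lanenet_Model_Weights/tusimple_lanenet_vgg.ckpt --image_path data/tusimple_test_image/0.jpg
4. Notes: the pictures folder is one I created to hold the detection result images; it is not in the original source. requirements_new.txt lists the library versions installed on my machine; they differ slightly from the original author's, but this does not affect the results. The data/tusimple_test_image folder contains test images, and detection works well on them. You can also run detection on your own images, but when I tested my own lane images the results were poor, even very bad, for reasons unknown at the time.
5. Test results
    The results on the bundled test image are decent, but when I switched to my own dataset, the lane lines flew up into the sky...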
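
Since missing weight files were the main obstacle, it can help to check the checkpoint before running test_lanenet.py. The value passed to --weights_path is a checkpoint prefix rather than a single file: TensorFlow's saver normally writes an .index file and one or more .data-* shards next to it. The sketch below only inspects file names. The function name missing_checkpoint_parts is hypothetical, and the exact file list inside the shared archive is an assumption.

```python
import glob
import os


def missing_checkpoint_parts(prefix):
    """Return a list of human-readable problems for a TF checkpoint prefix.

    An empty list means the .index file and at least one .data-* shard exist.
    """
    problems = []
    if not os.path.isfile(prefix + ".index"):
        problems.append("missing " + prefix + ".index")
    # glob.escape protects any special characters in the folder names.
    if not glob.glob(glob.escape(prefix) + ".data-*"):
        problems.append("no " + prefix + ".data-* shard found")
    return problems


if __name__ == "__main__":
    ckpt = "model/New_Tusimple_Lanenet_Model_Weights/tusimple_lanenet_vgg.ckpt"
    for line in missing_checkpoint_parts(ckpt) or ["checkpoint files look complete"]:
        print(line)
```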

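Update, June 11, 2020
I found out why the lane lines flew into the sky: the image resolution was wrong. The images need to be 1280x720, whereas mine were 1280x1024. After I adjusted the resolution, lane detection worked normally.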
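
The post does not say how the resolution was adjusted. One option, sketched here as an assumption, is to center-crop the frame to the aspect ratio of 1280x720 and then resize it, which avoids stretching the road. center_crop_box and to_target_resolution are hypothetical helpers; cv2 is imported lazily so the box arithmetic can be checked without OpenCV.

```python
TARGET_W, TARGET_H = 1280, 720  # resolution reported to work in the update above


def center_crop_box(width, height, target_w=TARGET_W, target_h=TARGET_H):
    """Return (left, top, crop_w, crop_h) of the largest centered box
    with the target aspect ratio that fits inside width x height."""
    # Compare aspect ratios by integer cross-multiplication to avoid float error.
    if width * target_h > height * target_w:
        # Image is too wide: keep the full height, trim the sides.
        crop_w, crop_h = height * target_w // target_h, height
    else:
        # Image is too tall (or already matches): keep the full width.
        crop_w, crop_h = width, width * target_h // target_w
    return (width - crop_w) // 2, (height - crop_h) // 2, crop_w, crop_h


def to_target_resolution(image):
    """Center-crop an image array to the target aspect ratio, then resize it."""
    import cv2  # imported here so center_crop_box works without OpenCV

    h, w = image.shape[:2]
    left, top, cw, ch = center_crop_box(w, h)
    cropped = image[top:top + ch, left:left + cw]
    return cv2.resize(cropped, (TARGET_W, TARGET_H), interpolation=cv2.INTER_LINEAR)


print(center_crop_box(1280, 1024))  # (0, 152, 1280, 720)
```

For the 1280x1024 images mentioned above, this drops 152 rows from the top and 152 from the bottom. Whether cropping or a plain resize works better depends on how the camera is mounted, so treat the crop position as something to tune.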


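Update, October 31, 2021
A reader needed a simple desktop GUI to display the detection results. I helped build one and am open-sourcing it here for reference.

The GUI is written in Python; if a library is missing, install it with pip.
Two files are involved: my_form.py, the GUI layout, and test_images.py, the lane detection program.
Put both files in the project root, lanenet-lane-detection-master, and run test_images.py. You can change the image path to run detection on different images.
test_images.py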

import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from my_form import Ui_Form
from PIL import Image
from PIL.ImageQt import ImageQt
import qdarkstyle

import argparse
import os.path as ops
import time

import cv2
import glog as log
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Add the project folders to the import path; change these to your own paths.
sys.path.append('C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master')
sys.path.append('C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master/config')
sys.path.append('C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master/data_provider')
sys.path.append('C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master/lanenet_model')
sys.path.append('C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master/tools')

from config import global_config
from lanenet_model import lanenet
from lanenet_model import lanenet_postprocess


class MyMainForm(QMainWindow, Ui_Form):
    def __init__(self, parent=None):
        super(MyMainForm, self).__init__(parent)
        self.setupUi(self)

        # self.resize(600, 400)
        self.setWindowTitle("LaneNet Lane Detection GUI")
                            
        self.pushButton.clicked.connect(self.openimage)
        self.pushButton_2.clicked.connect(self.detection_results)

    def openimage(self):
        # Remember the path typed into the line edit; detection_results() reads it.
        global image_path
        image_path = self.lineEdit.text()

    def detection_results(self):

        CFG = global_config.cfg
        
        def args_str2bool(arg_value):
            if arg_value.lower() in ('yes', 'true', 't', 'y', '1'):
                return True

            elif arg_value.lower() in ('no', 'false', 'f', 'n', '0'):
                return False
            else:
                raise argparse.ArgumentTypeError('Unsupported value encountered.')


        def minmax_scale(input_arr):
            min_val = np.min(input_arr)
            max_val = np.max(input_arr)

            output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)

            return output_arr

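        def messageDialog():
            QtWidgets.QMessageBox.warning(self, "Warning", "Failed to load the image path!", QtWidgets.QMessageBox.Cancel)


        def test_lanenet():
            # image_path is a module-level global set by openimage();
            # it does not exist until openimage() has run at least once.
            try:
                # image_path = "data/tusimple_test_image/0.jpg"
                global image_path
                weights_path = "model/New_Tusimple_Lanenet_Model_Weights/tusimple_lanenet_vgg.ckpt"
                assert ops.exists(image_path), '{:s} not exist'.format(image_path)
            except (NameError, AssertionError):
                # Stop here; otherwise cv2.imread below would get a bad path.
                messageDialog()
                return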
            log.info('Start reading image and preprocessing')
            t_start = time.time()
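            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                # cv2.imread returns None when the file is not a readable image.
                messageDialog()
                return
            image_vis = image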
            image = cv2.resize(image, (512, 256), interpolation=cv2.INTER_LINEAR)
            image = image / 127.5 - 1.0
            log.info('Image load complete, cost time: {:.5f}s'.format(time.time() - t_start))

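            # Start from an empty graph so that a repeated click does not add a
            # second copy of the network to the default graph (TensorFlow 1.x API).
            tf.reset_default_graph()
            input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')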

            net = lanenet.LaneNet(phase='test', net_flag='vgg')
            binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

            postprocessor = lanenet_postprocess.LaneNetPostProcessor()

            saver = tf.train.Saver()

            # Set sess configuration
            sess_config = tf.ConfigProto()
            sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
            sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
            sess_config.gpu_options.allocator_type = 'BFC'

            sess = tf.Session(config=sess_config)

            with sess.as_default():

                saver.restore(sess=sess, save_path=weights_path)

                t_start = time.time()
                binary_seg_image, instance_seg_image = sess.run(
                    [binary_seg_ret, instance_seg_ret],
                    feed_dict={input_tensor: [image]}
                )
                t_cost = time.time() - t_start
                log.info('Single image inference cost time: {:.5f}s'.format(t_cost))

                postprocess_result = postprocessor.postprocess(
                    binary_seg_result=binary_seg_image[0],
                    instance_seg_result=instance_seg_image[0],
                    source_image=image_vis
                )
                mask_image = postprocess_result['mask_image']
                
                
                for i in range(CFG.TRAIN.EMBEDDING_FEATS_DIMS):
                    instance_seg_image[0][:, :, i] = minmax_scale(instance_seg_image[0][:, :, i])
                embedding_image = np.array(instance_seg_image[0], np.uint8)
                        
                # Show the images in the GUI
                # Original image: OpenCV loads BGR, QImage.Format_RGB888 expects RGB
                img_src = cv2.cvtColor(image_vis, cv2.COLOR_BGR2RGB)
                # Read the label's width and height
                label_width = self.label_2.width()
                label_height = self.label_2.height()
                # Wrap the pixel data in a QImage
                temp_imgSrc = QImage(img_src, img_src.shape[1], img_src.shape[0],img_src.shape[1]*3, QImage.Format_RGB888)
                # Convert to a QPixmap for display
                pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
                # Show it in the label
                self.label.setPixmap(pixmap_imgSrc)

                # Show the original image again, this time in label_2
                img_src = cv2.cvtColor(image_vis, cv2.COLOR_BGR2RGB)
                # Read the label's width and height
                label_width = self.label_2.width()
                label_height = self.label_2.height()
                # Wrap the pixel data in a QImage
                temp_imgSrc = QImage(img_src, img_src.shape[1], img_src.shape[0],img_src.shape[1]*3, QImage.Format_RGB888)
                # Convert to a QPixmap for display
                pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
                # Show it in the label
                self.label_2.setPixmap(pixmap_imgSrc)

                # Show mask_image: BGR to RGB for Format_RGB888
                mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
                # Read the label's width and height
                label_width = self.label_3.width()
                label_height = self.label_3.height()
                # Wrap the pixel data in a QImage
                temp_mask = QImage(mask_image, mask_image.shape[1], mask_image.shape[0],mask_image.shape[1]*3, QImage.Format_RGB888)
                # Convert to a QPixmap for display
                pixmap_mask = QPixmap.fromImage(temp_mask).scaled(label_width, label_height)
                # Show it in the label
                self.label_3.setPixmap(pixmap_mask)

                # Show embedding_image: pick three embedding channels for a
                # false-color view; cvtColor also returns a contiguous array
                embedding_image = embedding_image[:, :, (2, 1, 0)]
                embedding_image = cv2.cvtColor(embedding_image,cv2.COLOR_BGR2RGB)
                # Read the label's width and height
                label_width = self.label_4.width()
                label_height = self.label_4.height()
                # Wrap the pixel data in a QImage
                temp_embed = QImage(embedding_image, embedding_image.shape[1], embedding_image.shape[0],embedding_image.shape[1]*3, QImage.Format_RGB888)
                # Convert to a QPixmap for display
                pixmap_embed = QPixmap.fromImage(temp_embed).scaled(label_width, label_height)
                # Show it in the label
                self.label_4.setPixmap(pixmap_embed)

                # Binarize the mask
                gray_image = cv2.cvtColor(mask_image, cv2.COLOR_RGB2GRAY)
                ret, binary_seg_image = cv2.threshold(gray_image, 10, 255, cv2.THRESH_BINARY)
                label_width = self.label_5.width()
                label_height = self.label_5.height()
                temp_QtImg = QImage(binary_seg_image.data,binary_seg_image.shape[1],binary_seg_image.shape[0],binary_seg_image.shape[1],QImage.Format_Indexed8)
                pixmap_QtImg = QPixmap.fromImage(temp_QtImg).scaled(label_width, label_height)
                self.label_5.setPixmap(pixmap_QtImg)

            sess.close()

            return

        test_lanenet()
                
if __name__ == "__main__":
    app = QtWidgets.QApplication(sys.argv)
    app.setStyleSheet(qdarkstyle.load_stylesheet())
    myWin = MyMainForm()
    myWin.show()
    sys.exit(app.exec_())

my_form.py

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'untitled.ui'
#
# Created by: PyQt5 UI code generator 5.15.2
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_Form(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(888, 561)
        self.label_6 = QtWidgets.QLabel(Form)
        self.label_6.setGeometry(QtCore.QRect(310, 10, 281, 29))
        self.label_6.setStyleSheet("font: 16pt \"Ubuntu\";")
        self.label_6.setAlignment(QtCore.Qt.AlignCenter)
        self.label_6.setObjectName("label_6")
        self.layoutWidget = QtWidgets.QWidget(Form)
        self.layoutWidget.setGeometry(QtCore.QRect(20, 50, 851, 501))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout_2.setContentsMargins(0, 0, 0, 0)
        self.gridLayout_2.setObjectName("gridLayout_2")
        self.gridLayout = QtWidgets.QGridLayout()
        self.gridLayout.setContentsMargins(-1, -1, -1, 10)
        self.gridLayout.setObjectName("gridLayout")
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_3.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_3.setAlignment(QtCore.Qt.AlignCenter)
        self.label_3.setObjectName("label_3")
        self.gridLayout.addWidget(self.label_3, 0, 2, 1, 1)
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_2.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_2.setAlignment(QtCore.Qt.AlignCenter)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 1, 1, 1)
        self.label_4 = QtWidgets.QLabel(self.layoutWidget)
        self.label_4.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_4.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_4.setAlignment(QtCore.Qt.AlignCenter)
        self.label_4.setObjectName("label_4")
        self.gridLayout.addWidget(self.label_4, 1, 1, 1, 1)
        self.label = QtWidgets.QLabel(self.layoutWidget)
        self.label.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label.setAutoFillBackground(False)
        self.label.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label.setAlignment(QtCore.Qt.AlignCenter)
        self.label.setObjectName("label")
        self.gridLayout.addWidget(self.label, 0, 0, 2, 1)
        self.label_5 = QtWidgets.QLabel(self.layoutWidget)
        self.label_5.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_5.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_5.setAlignment(QtCore.Qt.AlignCenter)
        self.label_5.setObjectName("label_5")
        self.gridLayout.addWidget(self.label_5, 1, 2, 1, 1)
        self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1)
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_2.setContentsMargins(-1, -1, -1, 10)
        self.horizontalLayout_2.setSpacing(16)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.label_7 = QtWidgets.QLabel(self.layoutWidget)
        self.label_7.setObjectName("label_7")
        self.horizontalLayout_2.addWidget(self.label_7)
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget)
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalLayout_2.addWidget(self.lineEdit)
        self.gridLayout_2.addLayout(self.horizontalLayout_2, 1, 0, 1, 1)
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setContentsMargins(-1, -1, -1, 0)
        self.horizontalLayout.setSpacing(16)
        self.horizontalLayout.setObjectName("horizontalLayout")
        spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton.setStyleSheet("background-color: rgb(114, 159, 207);")
        self.pushButton.setObjectName("pushButton")
        self.horizontalLayout.addWidget(self.pushButton)
        spacerItem1 = QtWidgets.QSpacerItem(308, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem1)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_2.setStyleSheet("background-color: rgb(114, 159, 207);")
        self.pushButton_2.setObjectName("pushButton_2")
        self.horizontalLayout.addWidget(self.pushButton_2)
        spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem2)
        self.gridLayout_2.addLayout(self.horizontalLayout, 2, 0, 1, 1)
        self.gridLayout_2.setRowStretch(0, 7)
        self.gridLayout_2.setRowStretch(1, 1)
        self.gridLayout_2.setRowStretch(2, 1)

        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label_6.setText(_translate("Form", "LaneNet Lane Detection GUI"))
        self.label_3.setText(_translate("Form", "IMAGES2"))
        self.label_2.setText(_translate("Form", "IMAGES1"))
        self.label_4.setText(_translate("Form", "IMAGES3"))
        self.label.setText(_translate("Form", "IMAGES"))
        self.label_5.setText(_translate("Form", "IMAGES4"))
        self.label_7.setText(_translate("Form", "Image path"))
        self.lineEdit.setText(_translate("Form", "data/tusimple_test_image/0.jpg"))
        self.pushButton.setText(_translate("Form", "Load image"))
        self.pushButton_2.setText(_translate("Form", "Process results"))