道路线检测 lanenet_pytorch

最近在学习道路线检测,在网上找了很多资源,在github上找到的开源的LaneNet项目数目较少,只有基于tensoflow 1.x,但是在配置环境的时候过于麻烦,同时由于tensorflow 2.x的缘故,在源代码上修改的时候也挺麻烦的,动不动就报错,最后也没跑起来,且相关作者也已不再维护。最后发现了一个基于pytorch的LaneNet的源码,并在其基础上进行了修改。

https://github.com/IrohXu/lanenet-lane-detection-pytorch

从github上获取到源代码后,通过python test.py --img ./data/tusimple_test_image/0.jpg 进行测试。

可以读取文件夹下的图片进行预测。

import argparse
import time
import os
import sys

import torch
from dataloader.transformers import Rescale
from model.lanenet.LaneNet import LaneNet
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms
import numpy as np
from PIL import Image
import pandas as pd
import cv2

# GPU or CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_test_data(img_path, transform):
    img = Image.open(img_path)
    img = transform(img)
    return img


def test():
    # 创建test_output文件夹
    if os.path.exists('test_output') == False:
        os.mkdir('test_output')
    args = parse_args()
    # input图片地址
    img_path = args.img

    # # resize后的图片大小
    # resize_height = args.height
    # resize_width = args.width

    # 图像处理
    data_transform = transforms.Compose([
        # transforms.Resize((resize_height, resize_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # 模型参数
    model_path = args.model
    # 模型结构
    model = LaneNet(arch=args.model_type)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)
    print('模型加载成功,开始对图片进行预测')
    # 不计算参数梯度,防止出现因图片size问题出现CUDA out of memory
    with torch.no_grad():
        imgs = os.listdir(img_path)
        imgs.sort(key=lambda x: int(x.split('.')[0]))
        for img in imgs:
            dummy_input = load_test_data(img_path + '/' + img, data_transform).to(DEVICE)
            dummy_input = torch.unsqueeze(dummy_input, dim=0)
            outputs = model(dummy_input)

            # input = Image.open(img_path)
            # input = input.resize((resize_width, resize_height))
            # input = np.array(input)

            instance_pred = torch.squeeze(outputs['instance_seg_logits'].detach().to('cpu')).numpy() * 255
            binary_pred = torch.squeeze(outputs['binary_seg_pred']).to('cpu').numpy() * 255
            # # 保存输入图片
            # cv2.imwrite(os.path.join('test_output', 'input.jpg'), input)

            # cv2.imwrite(os.path.join('test_output', img.split('.')[0] + '_instance.jpg'),
            #             instance_pred.transpose((1, 2, 0)))
            # 保存二值图
            cv2.imwrite(os.path.join('test_output', img.split('.')[0] + '_binary.jpg'), binary_pred)
            print(img.split('.'))

    print('over')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img", default="./video2imgs", help="Img path")
    parser.add_argument("--model_type", help="Model type", default='ENet')
    parser.add_argument("--model", help="Model path", default='./log/best_model.pth')
    parser.add_argument("--width", required=False, type=int, help="Resize width", default=512)
    parser.add_argument("--height", required=False, type=int, help="Resize height", default=256)
    parser.add_argument("--save", help="Directory to save output", default="./test_output")
    return parser.parse_args()


# def img2video(input_root, output_root):
#     img_root = input_root  # 读取图片目录
#     fps = 30  # 保存视频的FPS,可以适当调整
# 
#     # 编码器 可以用(*'DVIX')或(*'X264'),如果都不行先装ffmepg: sudo apt-get install ffmepg
#     fourcc = cv2.VideoWriter_fourcc(*'XVID')
# 
#     videoWriter = cv2.VideoWriter(output_root + '/predict.mp4', fourcc, fps,
#                                   (1280, 720))  # 视频写入;编码器;fps;图片的尺寸,根据自己的图片决定
#     # 遍历文件夹下所有图片,listdir为随机排序
#     imgnames = os.listdir(img_root)
#     # 将图片顺序排序
#     imgnames.sort(key=lambda x: int(x[:-4]))
#     for imgname in imgnames:
#         print(imgname)
#         # 读取图片
#         frame = cv2.imread(img_root + '/' + imgname)
#         videoWriter.write(frame)
#     videoWriter.release()
#     print("已经转为视频")


if __name__ == "__main__":
    test()
    # img2video('D:/lanenet-lane-detection-pytorch-main/test_output', 'D:/lanenet-lane-detection-pytorch-main/imgs2video')

默认input路径下全为图片,没有对路径下的文件进行判断是否为图片,有待完善 

  • 1
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
【资源说明】 1、该资源内项目代码都是经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载使用,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能。 环境部署 (1)我的环境配置 ``` 操作系统:Ubuntu20.04 IDE:vscode Python: 3.6.13 PyTorch: 1.10.2+cu113 CUDA:113 GPU:NVIDIA GeForce RTX 3090 ``` (2)完整的安装脚本 # Linux ​ 这里便是一个完整安装 MMSegmentation 的脚本,使用 conda 并链接了数据集的路径(以您的数据集路径为 $DATA_ROOT 来安装)。 ```shell conda create -n open-mmlab python=3.10 -y conda activate open-mmlab conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html csdn下载解压资源,命名为mmsegmentation cd mmsegmentation pip install -e . # 或者 "python setup.py develop" mkdir data ln -s $DATA_ROOT data ``` # Windows (有风险) ​ 这里便是一个完整安装 MMSegmentation 的脚本,使用 conda 并链接了数据集的路径(以您的数据集路径为 %DATA_ROOT% 来安装)。 注意:它必须是一个绝对路径。 ```shell conda create -n open-mmlab python=3.10 -y conda activate open-mmlab conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch set PATH=full\path\to\your\cpp\compiler;%PATH% pip install mmcv csdn下载解压资源,命名为mmsegmentation cd mmsegmentation pip install -e . # 或者 "python setup.py develop" mklink /D data %DATA_ROOT% ``` ## 二.数据集收集以及标注 (1)数据分析 ​ 使用官方提供的视频,每12帧提取1帧,总共提取583张图片,剔除后84张无车线图片,剩余499张数据样本。 ![](https://s2.loli.net/2022/05/21/PcU5Y1tZBa8FLMs.png) ​ 需要标注的数据区域为图片下1/3区域内的车线。过远区域车线不清晰,不利于模型的训练。只标注车行进的主车线。 (2)数据标注 ​ 数据标注我们选择使用labelme。其优势在于我们可以在任意地方使用该 工具。此外,它也可以帮助我们标注图像,不需要在电脑中安装或复制大型数据集。 标注方式:我们选择用多边形(Polygons)进行车线的标注。 ![](https://s2.loli.net/2022/05/21/bgeJK6hQY2R1XjW.png) (3)数据增强 ​ 在深度学习中,数据增强可以在样本数量不足或者样本质量不够好的情况下,提高样本质量,增加训练的数据量,提高模型的泛化能力,增加噪声数据,提升模型的鲁棒性。 ​ 我们对标注好的车线数据进行数据增强,数据增强的同时保留原有标注数据。对每张图片进行4次数据增强,包含改变亮度、加噪声、加随机点、水平翻转4种形式的数据增强,不同形式的数据增强会随机叠加。 ​ 修改DataAugmentforLabelMe.py文件里的数据集路径,运行后即可得到增强的数据集。 ![](https://s2.loli.net/2022/05/21/iIW3VdtZu2fK9w1.png) (4)数据集 ​ 数据集格式选择voc格式,将labelme标定好的json数据转voc格式。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值