【Pix2Pix】当生成式对抗神经网络遇到车道线检测

目前,卷积神经网络已经成功地应用于语义分割任务。然而,有许多问题本质上不是像素分类问题,但仍然经常被表述为语义分割,将像素概率图转换为最终所需的输出。

以车道线检测为例,目前车道线检测的难点为寻找语义上的线,而不是局限于表观存在的线。

但是生成对抗网络 (GAN) 可用于使语义分割网络的输出更真实或更好地保留结构。

一、数据集简介

本项目使用的是21年新出的车道线检测数据集VIL-100,这是一个包含100个视频,10000帧图像,涵盖10种车道线类型、各种驾驶场景、光照条件和多条车道线实体,同时对视频中的所有车道线提供了高质量的实体级标注。

# 数据可视化
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

root = "data/VIL100"
with open("data/VIL100/data/train.txt", "r") as trainList: 
    trainDatas = trainList.readlines()
    print('训练集数据量: {}'.format(len(trainDatas)))

with open("data/VIL100/data/test.txt", "r") as testList: 
    testDatas = testList.readlines()
    print('测试集数据量: {}'.format(len(testDatas)))

# 从训练集中随机抽取一张图像进行可视化
index = random.randint(0, len(trainDatas))
traindata = trainDatas[index].split(" ")
image = cv2.imread(root + traindata[0])
label = cv2.imread(root + traindata[1])
plt.figure(figsize=(10, 10))
plt.imshow(np.hstack([image, label])[:,:,::-1])
plt.show()
训练集数据量: 8000
测试集数据量: 2000

二、数据预处理

数据预处理部分与图像分割的数据处理类似,需要对输入图像进行归一化,并基于飞桨提供的paddle.io.Dataset基类,实现自定义数据集。

from paddle.io import Dataset

class VILData(Dataset):
    def __init__(self, mode='train'):
        super(VILData, self).__init__() 
        self.train_data_paths = self.load_train_data()# 获取训练集
        self.test_data_paths = self.load_test_data()  # 获取训练集
        self.mode = mode
        self.root = "data/VIL100"

    def __getitem__(self, idx):
        if self.mode == 'test':
            data_paths = self.test_data_paths
        else:
            data_paths = self.train_data_paths
        image = cv2.imread(self.root + data_paths[idx].split(" ")[0])
        image = (image / 255. * 2. - 1.).astype('float32')
        image = np.transpose(image, (2, 0, 1))
        label = cv2.imread(self.root + data_paths[idx].split(" ")[1])
        label = label.astype('float32')
        label = np.transpose(label, (2, 0, 1))

        return image, label

    def __len__(self):
        if self.mode == 'test':
            return len(self.test_data_paths)
        else:
            return len(self.train_data_paths)

    @staticmethod
    def load_train_data():
        data_path = 'data/VIL100/data/train.txt'
        with open(data_path, "r") as trainList: 
            return trainList.readlines()

    @staticmethod
    def load_test_data():
        data_path = 'data/VIL100/data/test.txt'
        with open(data_path, "r") as testList: 
            return testList.readlines()

traindataset = VILData('train')
testdataset = VILData('test')
# 从训练集中随机抽取一张图像进行可视化
index = random.randint(0, len(traindataset))
plt.imshow(np.transpose(traindataset[index][0], (1,2,0)))
plt.show()

在这里插入图片描述

# 可视化图像对应的标签
plt.imshow(np.transpose(traindataset[index][1], (1,2,0)))
plt.show()

在这里插入图片描述

三、模型组网

Pix2Pix,通过随机向量z和图像x生成需要图像y,即{z,x} -> y

生成器G用于生成尽可能愚弄判别器D的图像,判别器D尽可能分辨出生成器G生成的假图以及真实图像。

1.生成器的搭建

生成器G的结构采用的是U-Net。

在车道线检测任务中

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.郑先生_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值