当生成式对抗神经网络遇到车道线检测
目前,卷积神经网络已经成功地应用于语义分割任务。然而,有许多问题本质上不是像素分类问题,但仍然经常被表述为语义分割,将像素概率图转换为最终所需的输出。
以车道线检测为例,目前车道线检测的难点为寻找语义上的线,而不是局限于表观存在的线。
但是生成对抗网络 (GAN) 可用于使语义分割网络的输出更真实或更好地保留结构。
一、数据集简介
本项目使用的是21年新出的车道线检测数据集VIL-100
,这是一个包含100个视频,10000帧图像,涵盖10种车道线类型、各种驾驶场景、光照条件和多条车道线实体,同时对视频中的所有车道线提供了高质量的实体级标注。
- 更多介绍请查看官方论文:https://arxiv.org/abs/2108.08482
- 该数据集已上传至AI Studio:https://aistudio.baidu.com/aistudio/datasetdetail/115234
# 数据可视化
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
root = "data/VIL100"
with open("data/VIL100/data/train.txt", "r") as trainList:
trainDatas = trainList.readlines()
print('训练集数据量: {}'.format(len(trainDatas)))
with open("data/VIL100/data/test.txt", "r") as testList:
testDatas = testList.readlines()
print('测试集数据量: {}'.format(len(testDatas)))
# 从训练集中随机抽取一张图像进行可视化
index = random.randint(0, len(trainDatas))
traindata = trainDatas[index].split(" ")
image = cv2.imread(root + traindata[0])
label = cv2.imread(root + traindata[1])
plt.figure(figsize=(10, 10))
plt.imshow(np.hstack([image, label])[:,:,::-1])
plt.show()
训练集数据量: 8000
测试集数据量: 2000
二、数据预处理
数据预处理部分与图像分割的数据处理类似,需要对输入图像进行归一化,并基于飞桨提供的paddle.io.Dataset
基类,实现自定义数据集。
from paddle.io import Dataset
class VILData(Dataset):
def __init__(self, mode='train'):
super(VILData, self).__init__()
self.train_data_paths = self.load_train_data()# 获取训练集
self.test_data_paths = self.load_test_data() # 获取训练集
self.mode = mode
self.root = "data/VIL100"
def __getitem__(self, idx):
if self.mode == 'test':
data_paths = self.test_data_paths
else:
data_paths = self.train_data_paths
image = cv2.imread(self.root + data_paths[idx].split(" ")[0])
image = (image / 255. * 2. - 1.).astype('float32')
image = np.transpose(image, (2, 0, 1))
label = cv2.imread(self.root + data_paths[idx].split(" ")[1])
label = label.astype('float32')
label = np.transpose(label, (2, 0, 1))
return image, label
def __len__(self):
if self.mode == 'test':
return len(self.test_data_paths)
else:
return len(self.train_data_paths)
@staticmethod
def load_train_data():
data_path = 'data/VIL100/data/train.txt'
with open(data_path, "r") as trainList:
return trainList.readlines()
@staticmethod
def load_test_data():
data_path = 'data/VIL100/data/test.txt'
with open(data_path, "r") as testList:
return testList.readlines()
traindataset = VILData('train')
testdataset = VILData('test')
# 从训练集中随机抽取一张图像进行可视化
index = random.randint(0, len(traindataset))
plt.imshow(np.transpose(traindataset[index][0], (1,2,0)))
plt.show()
# 可视化图像对应的标签
plt.imshow(np.transpose(traindataset[index][1], (1,2,0)))
plt.show()
三、模型组网
Pix2Pix,通过随机向量z和图像x生成需要图像y,即{z,x} -> y
生成器G用于生成尽可能愚弄判别器D的图像,判别器D尽可能分辨出生成器G生成的假图以及真实图像。
1.生成器的搭建
生成器G的结构采用的是U-Net。
在车道线检测任务中