车道线分割项目记录-tusimple数据集处理

一、数据集包含信息

该项目训练所使用的数据集是tusimple数据集,其中用于训练及验证的有约3500张图,测试的有2000多张图。数据集中,除了图片,还包含了json文件,携带了车道线信息、文件路径。每一条数据如下所示:

{"lanes": [[-2, -2, -2, 348, 358, 357, 352, 347, 341, 331, 316, 301, 286, 271, 256, 241, 226, 211, 196, 182, 167, 152, 137, 122, 107, 92, 77, 62, 47, 32, 17, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, 427, 451, 469, 487, 504, 520, 526, 533, 539, 545, 551, 557, 564, 570, 576, 582, 588, 595, 601, 607, 613, 619, 626, 632, 638, 644, 650, 657, 663, 669, 675, 681, 688, 694, 700, 706, 712, 719, 725, 731, 737, 743, 750, -2, -2], [-2, -2, -2, 274, 263, 253, 227, 200, 173, 146, 118, 91, 64, 37, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, 552, 601, 649, 690, 721, 751, 782, 813, 844, 874, 905, 936, 967, 998, 1028, 1059, 1090, 1121, 1151, 1182, 1213, 1244, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2]], "h_samples": [240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710], "raw_file": "clips/0313-2/35080/20.jpg"}

可以看到,这里面有"lanes", "h_samples", "raw_file"三部分数据。

其中,"lanes"记录的是车道线的横坐标,就是图片的宽的坐标,"h_samples"记录的是纵坐标,就是图片的高的坐标,"raw_file"记录的是文件路径。

车道线可能有多条,比如上面这个例子里面就有4条,那么横坐标就有4组,纵坐标有1组,因此横纵坐标合在一起,构成了车道线所对应的点。

 

二、数据集创建

1.创建标签

标签只有车道线的几个点,但是模型需要的,一个是二值的语义分割标签,里面1的地方是车道线,其他地方是0,另一个是实例分割标签,比如不是车道线的地方是0,第一条是1,第二条是2等等。

这里,我创建语义分割标签图的时候,是通过cv2.line,把前后点连接起来,设定一个宽度,使标签变为以下这样:

 同样,创建实例分割标签的时候,方式一致,但是填的不全是1,而是1,2,3,4,就像下面这样(示例):

注意:创建完标签,可以保存为图像,保存为png格式,读取的时候,使用cv2.imread的时候,注意第二个参数传-1,cv2.imread(img_path, -1),就可以按照存的方式读取了。否则你如果保存的是二值图,读取的时候可能会变成三通道的。代码如下

def get_img_path_lanes(json_path):
    path_lane_data = []
    for file_path in json_path:
        with open(file_path,'r',encoding='utf-8') as f:
            data = f.readlines()
            for line in data:
                dicts = json.loads(line)
                lane_xy = []
                for lane in dicts['lanes']:
                    y = np.array([dicts['h_samples']]).T
                    x = np.array([lane]).T
                    lane_xy.append(np.hstack((x,y)))
                path_lane_data.append([dicts['raw_file'],lane_xy])
    return path_lane_data

def generate_labels():
    train_list = ['./data/train_set/label_data_0313.json', './data/train_set/label_data_0601.json']
    val_list = ['./data/train_set/label_data_0531.json']
    test_list = ['./data/test_set/test_tasks_0627.json']
    path_lane_train = get_img_path_lanes(train_list)
    path_lane_val = get_img_path_lanes(val_list)
    path_lane_test = get_img_path_lanes(test_list)
    paths = [path_lane_train, path_lane_val]
    label_folder = 'seg_label'

    # 创建语义分割标签图并保存
    if not os.path.exists('./seg_label/0313-1/6040'):
        print('未找到语义分割标签图路径,正在处理......')
        for one_path in paths:
            for item in one_path:
                item_save_path = label_folder+'/'+'/'.join(item[0].split('/')[1:3])
                # 找到并读取图片
                path = os.path.join('./data/train_set/', item[0])
                img = cv2.imread(path)
                # 获取高和宽
                H,W = img.shape[0], img.shape[1]
                # 创建标签图,全0
                mask = np.zeros((H,W))
                for k,lane in enumerate(item[1]):
                    # 选择第一张图的车道线来画,就画4根
                    if k == 4:
                        continue
                    for j in range(len(lane)-1):
                        if lane[j][0] != -2 and lane[j+1][0] != -2:
                            cv2.line(mask, tuple(lane[j]),tuple(lane[j+1]),1,8)
                    # cv_show(mask)
                # 存储Mask
                if not os.path.exists(item_save_path):
                    os.makedirs(item_save_path)
                cv2.imwrite(item_save_path+'/'+(item[0].split('/')[3]).split('.')[0]+'.png',mask)
        print('训练+验证集语义分割标签图保存完毕!')
    # 创建实例分割标签图并保存
    if not os.path.exists('./seg_label/instance_div/0313-1/6040'):
        print('未找到实例分割标签图路径,正在处理......')
        for one_path in paths:
            for item in one_path:
                item_save_path = label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])
                # 找到并读取图片
                path = os.path.join('./data/train_set/', item[0])
                img = cv2.imread(path)
                # 获取高和宽
                H,W = img.shape[0], img.shape[1]
                # 创建标签图,全0
                mask = np.zeros((H,W))
                for k,lane in enumerate(item[1]):
                    # 选择第一张图的车道线来画,就画4根
                    if k == 4:
                        continue
                    for j in range(len(lane)-1):
                        if lane[j][0] != -2 and lane[j+1][0] != -2:
                            cv2.line(mask, tuple(lane[j]),tuple(lane[j+1]),k+1,8)
                    # cv_show(mask)
                # 存储Mask
                if not os.path.exists(item_save_path):
                    os.makedirs(item_save_path)
                cv2.imwrite(item_save_path+'/'+(item[0].split('/')[3]).split('.')[0]+'.png',mask)
        print('训练+验证集实例分割标签图保存完毕!')

    # 创建txt文件路径并保存
    if not os.path.exists('./seg_label/train.txt'):
        print('未找到图片及标签汇总文件路径,正在处理......')
        file_paths = ['./seg_label/train.txt','./seg_label/val.txt','./seg_label/test.txt']
        for i,one_path in enumerate([path_lane_train,path_lane_val,path_lane_test]):
            if i == 0:
                with open(file_paths[i],'w',encoding='utf-8') as f:
                    for item in one_path:
                        num = '/'+(item[0].split('/')[3]).split('.')[0]
                        f.write('data/train_set/'+item[0]+' '+label_folder+'/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+' '
                                + label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+'\n')
            if i == 1:
                with open(file_paths[i],'w',encoding='utf-8') as f:
                    for item in one_path:
                        num = '/'+(item[0].split('/')[3]).split('.')[0]
                        f.write('data/train_set/'+item[0]+' '+label_folder+'/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+' '
                                + label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+'\n')
            if i == 2:
                with open(file_paths[i],'w',encoding='utf-8') as f:
                    for item in one_path:
                        f.write('data/test_set/'+item[0]+'\n')
        print('图片、语义分割标签、实例分割标签路径保存完毕!')

2.创建dataset类

这里还是创建的时候继承自torch.utils.data.Dataset类就可以,在类中实现__getitem__以及__len__,这里由于训练数据有3000多张,一次性导入内存的话,电脑会卡死,因此我改变了导入方式,只导入路径,__len__这里返回的也是路径列表,而在__getitem__里面,则是去索引路径信息,传入到一个专门读取并处理图片的函数中就可以了,处理的时候增加一些随机性,这里我加了亮度对比度的随机调整,随机旋转。

class LaneDataset(Dataset):
    '''
    输入训练数据路径、验证数据路径、测试数据路径,初始化的时候会自动创建索引文件,根据输入的mode不同,返回不同的处理好的样本。
    如果是train或者val,返回的是处理过的图片+处理过的标签(2个标签,一个是语义分割,一个是实例分割)
    如果是test,返回的是处理过的图片
    通过DataLoader之后,取出的是一个list,第一个元素是(B,3,H,W),第二个元素是List,有B个元素,每一个都是包含(2,H,W)的标签
    '''
    def __init__(self,resize_shape=(640,360), transform=None,rotate_theta=2, mode='train'):
        super(LaneDataset, self).__init__()

        self.transforms = transform
        self.mode = mode
        self.resize_shape = resize_shape
        self.rotate_theta=rotate_theta

        prepared_file_paths = ['./seg_label/train.txt','./seg_label/instance_div','./seg_label/0313-1']
        for prepared in prepared_file_paths:
            if not os.path.exists(prepared):
                print('预备文件路径缺少{},开始准备...'.format(prepared))
                generate_labels()
        self.data_list = self.get_path()[:800]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        if self.mode == 'train' or self.mode == 'val':
            processed_img, processed_labels = self.preprocess_data(self.data_list[idx])
            return processed_img,processed_labels
        elif self.mode=='test':
            processed_img = self.preprocess_data(self.data_list[idx])
            return processed_img

    def get_path(self):
        if self.mode not in ['train','val','test']:
            raise Exception('数据应当是train, val, test三者之一')
        # 根据mode取出标签
        modes = ['train','val']
        if self.mode in modes:
            data_list = []
            with open('./seg_label/{}.txt'.format(self.mode), 'r', encoding='utf-8') as f:
                all_paths = f.readlines()
                for path in all_paths:
                    # 每一个path.split,都是[“图片路径”,“语义分割标签路径”,“实例分割标签路径”]
                    data_list.append(path.strip().split())
        else:
            data_list = []
            with open('./seg_label/test.txt', 'r', encoding='utf-8') as f:
                all_paths = f.readlines()
                for path in all_paths:
                    data_list.append(path.strip())
        return data_list

    def preprocess_data(self, data_list):
        '''
        :return: 如果是train和val,就返回图片+标签数据,如果是test,就只返回测试图片
        '''
        if self.mode not in ['train','val','test']:
            raise Exception('数据应当是train, val, test三者之一')
        # 根据mode取出标签
        modes = ['train','val']
        if self.mode in modes:
            label_list = []
            img = cv2.imread(data_list[0])
            label_list.append(cv2.imread(data_list[1],-1))
            label_list.append(cv2.imread(data_list[2],-1))
            # 使用图本身的均值方差进行标准化,这块效果不好,还是使用网上通用的均值和方差吧。
            # 需要注意的是,如果以下面的方式操作,数据类型会发生改变,导致在dataloader拿数据
            # 的时候,出现错误,因此一定要记得把数据类型还原回uint8
            mean,std = cv2.meanStdDev(img)
            b,g,r = cv2.split(img)
            b1 = (b - mean[0]) / (1.e-6 + std[0])
            g1 = (g - mean[1]) / (1.e-6 + std[1])
            r1 = (r - mean[2]) / (1.e-6 + std[2])
            img = cv2.merge([b1, g1, r1]).astype('uint8')

            # # 转成RGB
            # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # 预处理:亮度对比度
            img = self.bright_contra_adjust(img)
            # resize
            img = cv2.resize(img, self.resize_shape, interpolation=cv2.INTER_CUBIC)
            label1 = cv2.resize(label_list[0], self.resize_shape, interpolation=cv2.INTER_NEAREST)
            label2 = cv2.resize(label_list[1], self.resize_shape, interpolation=cv2.INTER_NEAREST)
            # Rotation
            u = np.random.uniform()
            degree = (u-0.5) * self.rotate_theta
            R = cv2.getRotationMatrix2D((img.shape[1]//2, img.shape[0]//2),degree,1)
            img_rotate = cv2.warpAffine(img,R,(img.shape[1], img.shape[0]), flags=cv2.INTER_LINEAR)
            label1_rotate = cv2.warpAffine(label1,R,(label1.shape[1], label1.shape[0]), flags=cv2.INTER_NEAREST)
            label2_rotate = cv2.warpAffine(label2, R, (label2.shape[1], label2.shape[0]), flags=cv2.INTER_NEAREST)
            # transform
            img = self.transforms(img_rotate)
            return img, [label1_rotate, label2_rotate]
        else:
            img = cv2.imread(data_list)
            mean,std = cv2.meanStdDev(img)
            b,g,r = cv2.split(img)
            b1 = (b - mean[0]) / (1.e-6 + std[0])
            g1 = (g - mean[1]) / (1.e-6 + std[1])
            r1 = (r - mean[2]) / (1.e-6 + std[2])
            img = cv2.merge([b1, g1, r1]).astype('uint8')

            # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.resize_shape, interpolation=cv2.INTER_CUBIC)
            img = self.transforms(img)
            return img

    def bright_contra_adjust(self, img):
        '''亮度对比度调整,随机增加或减少0-10'''
        contra = random.uniform(0.85,1.15)
        bright = random.randint(-30,20)
        if random.uniform(0,1) > 0.5:
            return img
        else:
            img = img.astype(np.int)
            return np.uint8(np.clip(img*contra+ bright, 0, 255))
  • 0
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值