图像分割套件PaddleSeg全面解析(三)DataSet代码解读

在yaml配置文件中,我们配置的train_dataset的type为Cityscapes类型。通过以上Config代码的解读,我们知道了在第一次调用Config对象的train_dataset属性时会懒加载创建Cityscapes对象。
Cityscapes类的位置在paddleseg/datasets/cityscapes.py,Cityscapes的父类为Dataset,位于同目录下的dataset.py文件中,所以我先从Dataset类开始解读。

首先从Dataset的构造函数开始,构造函数比较长,里面包含了一些判断逻辑去初始化成员变量:

def __init__(self,
             transforms,#图像的transform
             dataset_root,#dataset的路劲
             num_classes, #类别数量
             mode='train', # 训练模式,train、val和test
             train_path=None, #训练列表文件路径,文件中每一行第一个是样本文件,第二个是标注文件。image1.jpg ground_truth1.png
             val_path=None, #验证列表文件路径,与训练文件一致。
             test_path=None,#与训练文件一致,其中标注文件不是必须的。
             separator=' ', #指定列表文件中样本文件和训练文件的分隔符,默认是空格
             ignore_index=255, #需要忽略的类别id
             edge=False): #是否在训练时计算边缘
    #保存数据的路径
    self.dataset_root = dataset_root
    #构建数据增强对象
    self.transforms = Compose(transforms)
    #新建一个保存文件路径的空列表
    self.file_list = list()
    #将模式类型字符串转换为小写并保存为成员变量
    mode = mode.lower()
    self.mode = mode
    #保存类别数
    self.num_classes = num_classes
    #保存需要忽略的类别编号,一般都是255
    self.ignore_index = ignore_index
    #保存edge
    self.edge = edge
    
    #如果mode不在train\val\test中,需要抛出异常。
    if mode.lower() not in ['train', 'val', 'test']:
        raise ValueError(
            "mode should be 'train', 'val' or 'test', but got {}.".format(
                mode))
    #数据增强对象必须指定,如果未设置,抛出异常。
    if self.transforms is None:
        raise ValueError("`transforms` is necessary, but it is None.")
    #如果数据集路径不存在则抛出异常。
    self.dataset_root = dataset_root
    if not os.path.exists(self.dataset_root):
        raise FileNotFoundError('there is not `dataset_root`: {}.'.format(
            self.dataset_root))
    #判断各个类型的文件列表是否存在,不存在抛出异常,存在则保存到file_path变量中。
    if mode == 'train':
        if train_path is None:
            raise ValueError(
                'When `mode` is "train", `train_path` is necessary, but it is None.'
            )
        elif not os.path.exists(train_path):
            raise FileNotFoundError(
                '`train_path` is not found: {}'.format(train_path))
        else:
            file_path = train_path
    elif mode == 'val':
        if val_path is None:
            raise ValueError(
                'When `mode` is "val", `val_path` is necessary, but it is None.'
            )
        elif not os.path.exists(val_path):
            raise FileNotFoundError(
                '`val_path` is not found: {}'.format(val_path))
        else:
            file_path = val_path
    else:
        if test_path is None:
            raise ValueError(
                'When `mode` is "test", `test_path` is necessary, but it is None.'
            )
        elif not os.path.exists(test_path):
            raise FileNotFoundError(
                '`test_path` is not found: {}'.format(test_path))
        else:
            file_path = test_path
    #打开列表文件,文件包含若干行,数量与数据集样本数量相同,训练集(train)和验证集(val)列表包含样本路径和标签文件路径。
    #测试集则只包含样本路径。
    with open(file_path, 'r') as f:
        #遍历列表文件中的每一行。
        for line in f:
            #分离样本路径和标签路径。
            items = line.strip().split(separator)
            #如果在训练集和验证集不包含样本路径和标签路径则抛出异常。
            if len(items) != 2:
                if mode == 'train' or mode == 'val':
                    raise ValueError(
                        "File list format incorrect! In training or evaluation task it should be"
                        " image_name{}label_name\\n".format(separator))
                image_path = os.path.join(self.dataset_root, items[0])
                label_path = None
            else:
                #拼接样本完整路径和标签完整路径
                image_path = os.path.join(self.dataset_root, items[0])
                label_path = os.path.join(self.dataset_root, items[1])
            #将样本路径和标签路径保存在列表中。
            self.file_list.append([image_path, label_path])

凡是在类中定义了这个__getitem__ 方法,那么它的实例对象(假定为p),可以像这样p[key] 取值,当实例对象做p[key] 运算时,会调用类中的方法__getitem__。
这样对象就可通过下标进行查找对象。对象就可以成为一个可迭代对象。

下面解读在Dataset类中,如何通过file_list返回样本和标签。

def __getitem__(self, idx):
	  #通过idx下标,在file_list里获取样本图片路径和标签图片路径。
      image_path, label_path = self.file_list[idx]
      #如果是测试模式则返回图片ndarray类型的数据。在transforms中,包含了图片的读取和预处理,不同模式的dataset类的transforms对象是不同的。
      if self.mode == 'test':
            im, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            return im, image_path
      #如果是训练或者验证模式还需要返回样本图片的和标签图片的ndarray的数据类型。
      elif self.mode == 'val':
            im, _ = self.transforms(im=image_path)
            label = np.asarray(Image.open(label_path))
            label = label[np.newaxis, :, :]
            return im, label
      else:
            im, label = self.transforms(im=image_path, label=label_path)
            if self.edge:
                edge_mask = F.mask_to_binary_edge(
                    label, radius=2, num_classes=self.num_classes)
                return im, label, edge_mask
            else:
                return im, label

在类中定义了__len__方法,可以使用len函数来获得长度,在Dataset类中保存文件列表,所以需要通过len函数来获取数据集中样本的数量,所以在Dataset类中还需要实现__len__方法。

def __len__(self):
      #该方法直接返回file_list列表的长度即可。
      return len(self.file_list)

上面解读了Dataset类的实现,下面我们在来看看一个实际的数据集Cityscapes。
Cityscapes类定义dygraph/paddleseg/datasets/cityscapes.py文件中,该类是Dataset的子类。自然它继承了__getitem__和__len__方法,这两个方法
中的代码是可复用的。在__getitem__中包含了对样本图片和标签图片的预处理,这部分不论是什么数据集操作应该类似的,及时在预处理有不同的地方也可以通过传递
transforms对象来处理,所以在Cityscapes类中,我们只关心构造函数即可。

def __init__(self, transforms, dataset_root, mode='train', edge=False):
      #这部分与Dataset类基本一致,保存一些成员变量,不过这里面指定了该数据集共有19类,同时直接指定了ignore_index为255.
      self.dataset_root = dataset_root
      self.transforms = Compose(transforms)
      self.file_list = list()
      mode = mode.lower()
      self.mode = mode
      self.num_classes = 19
      self.ignore_index = 255
      self.edge = edge

      if mode not in ['train', 'val', 'test']:
          raise ValueError(
              "mode should be 'train', 'val' or 'test', but got {}.".format(
                  mode))

      if self.transforms is None:
          raise ValueError("`transforms` is necessary, but it is None.")
      #由于不同的数据集文件组织结构会不同,在Cityscapes数据集中样本图片和标签图片分别保存在leftImg8bit和gtFine路径下。
      img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
      label_dir = os.path.join(self.dataset_root, 'gtFine')
      if self.dataset_root is None or not os.path.isdir(
              self.dataset_root) or not os.path.isdir(
                  img_dir) or not os.path.isdir(label_dir):
          raise ValueError(
              "The dataset is not Found or the folder structure is nonconfoumance."
          )
      #这里没有使用读取列表文件的方式获取样本图片列表和标签图片列表,而是通过glob方法使用正则化的方法匹配对应的文件来获取标签图片路径。
      label_files = sorted(
          glob.glob(
              os.path.join(label_dir, mode, '*',
                           '*_gtFine_labelTrainIds.png')))
      #跟上面一样获取样本图片路径列表。
      img_files = sorted(
          glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
      #构建文件列表,每一个元素,是包含两个元素的列表,形式为[样本图片路径,标签图片路径],供父类的__getitem__调用去预处理图片数据。
      self.file_list = [[
          img_path, label_path
      ] for img_path, label_path in zip(img_files, label_files)]

数据集部分的主要代码基本解读完毕,这里我们了解到在PaddleSeg套件中已经提供了Dataset类作为基类,所以如果我们想添加新的数据集则可以继承
Dataset类,然后自己实现__init__构造方法即可,可参考Cityscapes类实现。

PaddleSeg仓库地址:https://github.com/PaddlePaddle/PaddleSeg

  • 3
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

人工智能研习社

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值