PaddleDetection-MaskRcnn读入程序解析

2021SC@SDUSC

ppdet/data/reader.py源码分析

首先是在yaml上的配置:

文件./_base_/datasets/coco.yml

'''

metric: COCO      # 验证模型的评测标准,可以选择COCO或者VOC
                  # 用于训练或验证的数据集的类别数目,注意这里不含背景类
                  # RCNN系列中包含背景类,即81=80 + 1(背景类)
num_classes: 80   #类别数量



TrainDataset:                             #训练数据
  !COCODataSet                            #COCO数据集
    image_dir: train2017                  # 图片文件夹相对路径,路径是相对于dataset_dir,图像路径= dataset_dir + image_dir + image_name
    anno_path: annotations/instances_train2017.json # anno_path,路径是相对于dataset_dir
    dataset_dir: dataset/coco             # 数据集相对路径,路径是相对于PaddleDetection

EvalDataset:                             #验证数据
  !COCODataSet                           #COCO数据集
    image_dir: val2017                   #图片文件夹相对路径,路径是相对于dataset_dir,图像路径= dataset_dir + image_dir + image_name
    anno_path: annotations/instances_val2017.json  #标签目录,路径是相对于dataset_dir
    dataset_dir: dataset/coco            #数据集相对路径,路径是相对于PaddleDetection

TestDataset:                             #测试数据
  !ImageFolder                           
    anno_path: annotations/instances_val2017.json  #标签目录,路径是相对于dataset_dir
'''

然后是./_base_/readers/mask_fpn_reader.yml #流程基本都是相同的,数据处理会根据算法 相应的做一些调整

'''
worker_num: 2   #数据读取线程数
TrainReader: # 训练过程中模型的输入设置
  sample_transforms: #单张图片数据前处理,数据增强,下面是各种数据增强方法,放入列表中
  - DecodeOp: {}
  - RandomFlipImage: {prob: 0.5, is_mask_flip: true}
  - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
  - Permute: {to_bgr: false, channel_first: true}
  batch_transforms:#数据批处理
  - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true}
  batch_size: 1 # 1个GPU的batch size,默认为1。需要注意:每个iter迭代会运行batch_size * device_num张图片
  shuffle: true  #数据是否随机
  drop_last: true #是否丢弃最后与设置维度不匹配的数据 # 注意,在某些情况下,drop_last=false时训练过程中可能会出错,建议训练时都设置为true


EvalReader: #验证数据读取
  sample_transforms: #单张图片数据前处理,数据增强,下面是各种数据增强方法,放入列表中
  - DecodeOp: {}
  - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - PermuteOp: {}
  batch_transforms: #数据批处理
  - PadBatchOp: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1 # 1个GPU的batch size,默认为1。需要注意:每个iter迭代会运行batch_size * device_num张图片
  shuffle: false  #数据是否随机
  drop_last: false # 注意,在某些情况下,drop_last=false时训练过程中可能会出错,建议训练时都设置为true
  drop_empty: false  #丢弃空数据


TestReader: #测试数据读取,有些前处理需要保持一致
  sample_transforms: #单张图片数据前处理,数据增强,下面是各种数据增强方法,放入列表中
  - DecodeOp: {}
  - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - PermuteOp: {}
  batch_transforms: #数据批处理
  - PadBatchOp: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1 # 1个GPU的batch size,默认为1。需要注意:每个iter迭代会运行batch_size * device_num张图片
  shuffle: false #数据是否随机
  drop_last: false #是否丢弃最后与设置维度不匹配的数据 # 注意,在某些情况下,drop_last=false时训练过程中可能会出错,建议训练时都设置为true
'''

引用相关库: 

import copy
import traceback
import six
import sys
import multiprocessing as mp
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np

from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from ppdet.core.workspace import register, serializable, create
from . import transform

from ppdet.utils.logger import setup_logger
logger = setup_logger('reader')
class Compose(object):
    def __init__(self, transforms, num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                self.transforms_cls.append(op_cls(**v))
                if hasattr(op_cls, 'num_classes'):
                    op_cls.num_classes = num_classes

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
                            format(f, e, str(stack_info)))
                raise e

        return data

该类是单张图片数据增强类,多种单张数据增强方式都在transforms列表中 ,通过遍历该列表对图片进行多种数据增强最后返回增强后的结果。

class BatchCompose(Compose):
    def __init__(self, transforms, num_classes=81):
        super(BatchCompose, self).__init__(transforms, num_classes)
        self.output_fields = mp.Manager().list([])
        self.lock = mp.Lock()

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
                            format(f, e, str(stack_info)))
                raise e

        # parse output fields by first sample
        # **this shoule be fixed if paddle.io.DataLoader support**
        # For paddle.io.DataLoader not support dict currently,
        # we need to parse the key from the first sample,
        # BatchCompose.__call__ will be called in each worker
        # process, so lock is need here.
        if len(self.output_fields) == 0:
            self.lock.acquire()
            if len(self.output_fields) == 0:
                for k, v in data[0].items():
                    # FIXME(dkp): for more elegent coding
                    if k not in ['flipped', 'h', 'w']:
                        self.output_fields.append(k)
            self.lock.release()

        data = [[data[i][k] for k in self.output_fields]
                for i in range(len(data))]
        data = list(zip(*data))

        batch_data = [np.stack(d, axis=0) for d in data]
        return batch_data

 此类为批量图片数据增强类 , 同Compose,这里是对批量数据进行增强

class BaseDataLoader(object):
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        # sample transform
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # batch transfrom 
        self._batch_transforms = BatchCompose(batch_transforms, num_classes)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        self.dataset = dataset
        self.dataset.parse_dataset(self.with_background)
        # get data
        self.dataset.set_transform(self._sample_transforms)
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)
        self.loader = iter(self.dataloader)

        return self

    def __len__(self):
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        # pack {filed_name: field_data} here
        # looking forward to support dictionary
        # data structure in paddle.io.DataLoader
        try:
            data = next(self.loader)
            return {
                k: v
                for k, v in zip(self._batch_transforms.output_fields, data)
            }
        except StopIteration:
            self.loader = iter(self.dataloader)
            six.reraise(*sys.exc_info())

    def next(self):
        # python2 compatibility
        return self.__next__()

 该类为数据加载基类 ,调用Compose和BatchCompose中的方法进行数据增强,最后迭代输出数据(通过call调用)。

'''
@register
@serializable
class IouLoss(object):


    def __init__(self,):


    def __call__(self, s):
        
        return s
'''

 @的含义: #Python当解释器读到@的这样的修饰符之后,会先解析@后的内容, 直接就把@下一行的函数或者类作为@后边的函数的参数, 然后将返回值赋值给下一行修饰的函数对象。

@register
class TrainReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(inputs_def, sample_transforms,
                                          batch_transforms, batch_size, shuffle,
                                          drop_last, drop_empty, num_classes,
                                          with_background, **kwargs)

训练数据加载类,会将yaml文件对应的参数传入相应的类名中,包含的成员也会将成员类中的参数传入相应的类中去 ,例如TrainDataset的配置如下,那TrainReader就会传入下面所有的参数,其中的成员也会传入相应的参数实例相应的类,实现TrainReader。

@register
class EvalReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)

验证数据加载类,会将yaml文件对应的参数传入相应的类名中,包含的成员也会将成员类中的参数传入相应的类中去。 

@register
class TestReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)

 测试数据加载类,会将yaml文件对应的参数传入相应的类名中,包含的成员也会将成员类中的参数传入相应的类中去。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值