[拆轮子] PaddleDetection 中的 COCODataSet 是怎么写的

今日,辗转反侧,该💩的代码就是跑不成功,来看看 COCODataSet 到底是怎么写的,本文只参考当前版本的代码,当前版本 PaddleDetection2.5 COCODataSet 源码见本文附录,(本文适用于有一定Python基础的童鞋看)

COCODataSet 类内部就三个函数:

__init__
parse_dataset  
_sample_empty     # 该函数供 parse_dataset 调用

来看一下 COCODataSet 的基类实现函数,咱挨个看

__init__
__len__
__call__
__getitem__
check_or_download_dataset
set_kwargs
set_transform
set_epoch
parse_dataset
get_anno
1. 基类parse_dataset
def parse_dataset(self, ):
    raise NotImplementedError(
        "Need to implement parse_dataset method of Dataset")

该类必须要被继承之后实现该方法,继承该类中必须解析数据集,并将数据集中的内容传给变量 self.roidbs,具体内容之后看, self.roidbs 变量是一个列表,每一项都是一张照片的内容

parse_dataset 唯一要做的一件事就是解析数据并传给变量 self.roidbs

self.roidbs 中一个 item 是:

{'gt_bbox': array([[133.51,  24.77, 366.11, 562.92]], dtype=float32),
 'gt_class': array([[14]], dtype=int32),
 'h': 640.0,
 'im_file': 'dataset/coco/COCO/val2017/000000270705.jpg',
 'im_id': array([270705]),
 'is_crowd': array([[0]], dtype=int32),
 'w': 475.0}
2. 基类__len__
def __len__(self, ):
    return len(self.roidbs) * self.repeat

len(self.roidbs) 就是原始数据的内容,self.repeat 是重复次数,所以在__getitem__ 有这么一句:

if self.repeat > 1:
    idx %= n

用来进行重复操作

3. 基类__call__
def __call__(self, *args, **kwargs):
    return self

做这个操作其实没啥说的了,实例化之后call一下还是返回自己

4. 基类其他不重要函数
  • 设置部分,用来设置自身的属性,基本没被调用
def set_kwargs(self, **kwargs):
    self.mixup_epoch = kwargs.get('mixup_epoch', -1)
    self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
    self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

def set_transform(self, transform):
    self.transform = transform

def set_epoch(self, epoch_id):
    self._epoch = epoch_id
  • 获取部分:
def get_anno(self):
    if self.anno_path is None:
        return
    return os.path.join(self.dataset_dir, self.anno_path)

获取标注 ann.json 的路径

  • 检查数据路径函数,也没被调用,不重要跳过
def check_or_download_dataset(self):
    self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
                                        self.image_dir)

以上函数供 read 在 dataset类 外部调用(之后会讲到)
在这里插入图片描述

所以 self.mixup_epoch , self.cutmix_epochself.mosaic_epoch 默认值都是 -1

5. 基类 __getitem__ 函数
def __getitem__(self, idx):
	
	# ------- 用来进行重复操作的部分 -------
    n = len(self.roidbs)
    if self.repeat > 1:
        idx %= n


    # ------- 深拷贝当前的数据项 -------
    roidb = copy.deepcopy(self.roidbs[idx])
    # 以下仨 if 和数据增强有关
    if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
        idx = np.random.randint(n)
        roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
    elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
        idx = np.random.randint(n)
        roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
    elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
        roidb = [roidb, ] + [
            copy.deepcopy(self.roidbs[np.random.randint(n)])
            for _ in range(4)
        ]
    
	# ------- 设置 curr_iter -------
    if isinstance(roidb, Sequence):
        for r in roidb:
            r['curr_iter'] = self._curr_iter
    else:
        roidb['curr_iter'] = self._curr_iter
    self._curr_iter += 1
    
    # ------- 对当前数据项进行之前的 transform ------- 
    return self.transform(roidb)

注意注意注意!!!
roidb['gt_bbox'] 返回的是 x 1 y 1 x 2 y 2 x_1y_1x_2y_2 x1y1x2y2,不是COCO数据集原本的 x 1 y 1 w h x_1y_1wh x1y1wh(左上角宽高)

这里可以验证一下:

import cv2
im = cv2.imread(roidb['im_file'])
x1, y1, x2, y2 = roidb['gt_bbox'][4].astype(int)
xx = cv2.rectangle(im, (x1, y1), (x2, y2), 255, thickness=2, lineType=8)
cv2.imwrite("xxx.png", xx)
6. 基类 __init__ 函数
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
self.anno_path = anno_path
self.image_dir = image_dir if image_dir is not None else ''
self.data_fields = data_fields           

看上边这4个参数,是和 yaml 文件中的内容是对应的:
在这里插入图片描述
基本都在 parse_dataset 调用

self.sample_num = sample_num                # parse_dataset 中调用
self.use_default_label = use_default_label  # 这个变量可能是 COCO 每个id对应的类名? 暂时没发现使用处
self.repeat = repeat
self._epoch = 0
self._curr_iter = 0
5. 子类 parse_dataset 函数

解析数据集部分,先读取

def parse_dataset(self):
	
	# ------ 先拿到标注和图片的路径 ------
    anno_path = os.path.join(self.dataset_dir, self.anno_path)
    # 'dataset/coco/COCO/annotations/instances_val2017.json'
    image_dir = os.path.join(self.dataset_dir, self.image_dir)
    # 'dataset/coco/COCO/val2017'

    assert anno_path.endswith('.json'), \
        'invalid coco annotation file: ' + anno_path
    from pycocotools.coco import COCO
    coco = COCO(anno_path)
	
	# ------ 拿到每张图片的 img_id ------
    img_ids = coco.getImgIds()
    img_ids.sort()
	
	
    # ------ 拿到COCO数据集类别的 cat_id ------
    cat_ids = coco.getCatIds()
    
    records = []
    empty_records = []
    ct = 0    # 用来进行数据计数的

    self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
    self.cname2cid = dict({
        coco.loadCats(catid)[0]['name']: clsid
        for catid, clsid in self.catid2clsid.items()
    })

    if 'annotations' not in coco.dataset:
        self.load_image_only = True
        logger.warning('Annotation file: {} does not contains ground truth '
                       'and load image information only.'.format(anno_path))

COCO数据集的类别和训练用的类别对应,也就是变量self.catid2clsid

{1: 0,
 2: 1,
 3: 2,
 4: 3,
 5: 4,
 6: 5,
 7: 6,
 8: 7,
 9: 8,
 10: 9,
 11: 10,
 13: 11,
 14: 12,
 15: 13,
 16: 14,
 17: 15,
 18: 16,
 19: 17,
 20: 18,
 21: 19,
 22: 20,
 23: 21,
 24: 22,
 25: 23,
 27: 24,
 28: 25,
 31: 26,
 32: 27,
 33: 28,
 34: 29,
 35: 30,
 36: 31,
 37: 32,
 38: 33,
 39: 34,
 40: 35,
 41: 36,
 42: 37,
 43: 38,
 44: 39,
 46: 40,
 47: 41,
 48: 42,
 49: 43,
 50: 44,
 51: 45,
 52: 46,
 53: 47,
 54: 48,
 55: 49,
 56: 50,
 57: 51,
 58: 52,
 59: 53,
 60: 54,
 61: 55,
 62: 56,
 63: 57,
 64: 58,
 65: 59,
 67: 60,
 70: 61,
 72: 62,
 73: 63,
 74: 64,
 75: 65,
 76: 66,
 77: 67,
 78: 68,
 79: 69,
 80: 70,
 81: 71,
 82: 72,
 84: 73,
 85: 74,
 86: 75,
 87: 76,
 88: 77,
 89: 78,
 90: 79}

变量 self.cname2cid 类别与 id 的对应字典:

{
	'person': 0
	'bicycle': 1
	'car': 2
	'motorcycle': 3
	'airplane': 4
	'bus': 5
	'train': 6
	'truck': 7
	'boat': 8
	'traffic light': 9
	'fire hydrant': 10
	'stop sign': 11
	'parking meter': 12
	'bench': 13
	'bird': 14
	'cat': 15
	'dog': 16
	'horse': 17
	'sheep': 18
	'cow': 19
	'elephant': 20
	'bear': 21
	'zebra': 22
	'giraffe': 23
	'backpack': 24
	'umbrella': 25
	'handbag': 26
	'tie': 27
	'suitcase': 28
	'frisbee': 29
	'skis': 30
	'snowboard': 31
	'sports ball': 32
	'kite': 33
	'baseball bat': 34
	'baseball glove': 35
	'skateboard': 36
	'surfboard': 37
	'tennis racket': 38
	'bottle': 39
	'wine glass': 40
	'cup': 41
	'fork': 42
	'knife': 43
	'spoon': 44
	'bowl': 45
	'banana': 46
	'apple': 47
	'sandwich': 48
	'orange': 49
	'broccoli': 50
	'carrot': 51
	'hot dog': 52
	'pizza': 53
	'donut': 54
	'cake': 55
	'chair': 56
	'couch': 57
	'potted plant': 58
	'bed': 59
	'dining table': 60
	'toilet': 61
	'tv': 62
	'laptop': 63
	'mouse': 64
	'remote': 65
	'keyboard': 66
	'cell phone': 67
	'microwave': 68
	'oven': 69
	'toaster': 70
	'sink': 71
	'refrigerator': 72
	'book': 73
	'clock': 74
	'vase': 75
	'scissors': 76
	'teddy bear': 77
	'hair drier': 78
	'toothbrush': 79
}

接下来这部分开始读取数据

for img_id in img_ids:
	# 拿到当前图片的信息
    img_anno = coco.loadImgs([img_id])[0]
    ‘’‘
    img_anno 的内容:
    {'coco_url': 'http://images.cocodataset.org/val2017/000000000139.jpg',
	 'date_captured': '2013-11-21 01:34:01',
	 'file_name': '000000000139.jpg',
	 'flickr_url': 'http://farm9.staticflickr.com/8035/8024364858_9c41dc1666_z.jpg',
	 'height': 426,
	 'id': 139,
	 'license': 2,
	 'width': 640}
    ’‘’
    
    im_fname = img_anno['file_name']
    im_w = float(img_anno['width'])
    im_h = float(img_anno['height'])
	
	# 拿到本地的图片路径
    im_path = os.path.join(image_dir,
                           im_fname) if image_dir else im_fname
    is_empty = False

	# ------- 判断图片的合法性 ------- 
    if not os.path.exists(im_path):
        logger.warning('Illegal image file: {}, and it will be '
                       'ignored'.format(im_path))
        continue

    if im_w < 0 or im_h < 0:
        logger.warning('Illegal width: {} or height: {} in annotation, '
                       'and im_id: {} will be ignored'.format(
                           im_w, im_h, img_id))
        continue
	
	# 拿到图片的信息,否则是空字典
    coco_rec = {
        'im_file': im_path,
        'im_id': np.array([img_id]),
        'h': im_h,
        'w': im_w,
    } if 'image' in self.data_fields else {}

self.data_fields 是:
在这里插入图片描述

开始根据当前图片 image_id 来读取标注

if not self.load_image_only:
	
	# 拿到图片id对应的标注 ann_id
    ins_anno_ids = coco.getAnnIds(
        imgIds=[img_id], iscrowd=None if self.load_crowd else False)
    
    # 根据标注 ann_id 来读取标注信息
    instances = coco.loadAnns(ins_anno_ids)

    bboxes = []
    is_rbox_anno = False

	# ----------- 加载每一个标注信息 -----------
    for inst in instances:
    	
    	'''
    	inst 的内容
    	{
	    	'segmentation': [[240.86, 211.31, 240.16, 197.19, 236.98, 192.26, 237.34, 187.67, 245.8, ...]]
			'area': 531.8071000000001
			'iscrowd': 0
			'image_id': 139
			'bbox': [236.98, 142.51, 24.7, 69.5]
			'category_id': 64
			'id': 26547
		}
    	'''
    
        # ----- 检查 gt bbox 有效性 -----
        if inst.get('ignore', False):
            continue
        if 'bbox' not in inst.keys():
            continue
        else:
            if not any(np.array(inst['bbox'])):
                continue

		# ---- 注意 COCO 数据集 json 标注的是 左上角+宽高 ----
        x1, y1, box_w, box_h = inst['bbox']
        x2 = x1 + box_w
        y2 = y1 + box_h
		# 这里转化为了 x1y1x2y2
		
		# --------- 接下来检验下 box 有效性 ---------
        eps = 1e-5
        if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
            inst['clean_bbox'] = [
                round(float(x), 3) for x in [x1, y1, x2, y2]
            ]
            bboxes.append(inst)
        else:
            logger.warning(
                'Found an invalid bbox in annotations: im_id: {}, '
                'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                    img_id, float(inst['area']), x1, y1, x2, y2))

接下来将数据存在 numpy array 中:

num_bbox = len(bboxes)
if num_bbox <= 0 and not self.allow_empty:
    continue
elif num_bbox <= 0:
    is_empty = True

# 根据数量创建空的 numpy 数组
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox


has_segmentation = False
for i, box in enumerate(bboxes):

    catid = box['category_id'] # 这个是 COCO 类别要换成 0-79 的
    gt_class[i][0] = self.catid2clsid[catid]
    
    gt_bbox[i, :] = box['clean_bbox']
    is_crowd[i][0] = box['iscrowd']

	
	# --- 由于暂时用不到 segmentation 信息直接跳过 ---
    # check RLE format 
    if 'segmentation' in box and box['iscrowd'] == 1:
        gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
    elif 'segmentation' in box and box['segmentation']:
        if not np.array(box['segmentation']
                        ).size > 0 and not self.allow_empty:
            bboxes.pop(i)
            gt_poly.pop(i)
            np.delete(is_crowd, i)
            np.delete(gt_class, i)
            np.delete(gt_bbox, i)
        else:
            gt_poly[i] = box['segmentation']
        has_segmentation = True

if has_segmentation and not any(
        gt_poly) and not self.allow_empty:
    continue

# 最后将标注信息放在该 gt_rec 字典中
gt_rec = {
    'is_crowd': is_crowd,
    'gt_class': gt_class,
    'gt_bbox': gt_bbox,
    'gt_poly': gt_poly,
}

接下来根据 yaml 中 data_fields 字段将信息放在 coco_rec 字典中

for k, v in gt_rec.items():
	if k in self.data_fields:
	    coco_rec[k] = v
if is_empty: # 如果当前图片没有对应标注
    empty_records.append(coco_rec)
else:
    records.append(coco_rec)
ct += 1# ct 用来计数


# 这个字段可以用来截取数据的长度
if self.sample_num > 0 and ct >= self.sample_num:
    break

最后这部分用来在没有标注的图片列表empty_records中 sample,调用了self._sample_empty

if self.allow_empty and len(empty_records) > 0:
	empty_records = self._sample_empty(empty_records, len(records))
	records += empty_records

附录

顺便备份一下当前版本PaddleDetection2.5COCODataSet 代码

class COCODataSet(DetDataset):
    """
    Load dataset with COCO format.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): coco annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): whether to load crowded ground-truth. 
            False as default
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty record number to total 
            record's, if empty_ratio is out of [0. ,1.), do not sample the 
            records and use all the empty entries. 1. as default
        repeat (int): repeat times for dataset, use in benchmark.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1):
        super(COCODataSet, self).__init__(
            dataset_dir,
            image_dir,
            anno_path,
            data_fields,
            sample_num,
            repeat=repeat)
        self.load_image_only = False
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        ct = 0

        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        self.cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        })

        if 'annotations' not in coco.dataset:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            coco_rec = {
                'im_file': im_path,
                'im_id': np.array([img_id]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                ins_anno_ids = coco.getAnnIds(
                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                instances = coco.loadAnns(ins_anno_ids)

                bboxes = []
                is_rbox_anno = False
                for inst in instances:
                    # check gt bbox
                    if inst.get('ignore', False):
                        continue
                    if 'bbox' not in inst.keys():
                        continue
                    else:
                        if not any(np.array(inst['bbox'])):
                            continue

                    x1, y1, box_w, box_h = inst['bbox']
                    x2 = x1 + box_w
                    y2 = y1 + box_h
                    eps = 1e-5
                    if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                        inst['clean_bbox'] = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        bboxes.append(inst)
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: im_id: {}, '
                            'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                img_id, float(inst['area']), x1, y1, x2, y2))

                num_bbox = len(bboxes)
                if num_bbox <= 0 and not self.allow_empty:
                    continue
                elif num_bbox <= 0:
                    is_empty = True

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_poly = [None] * num_bbox

                has_segmentation = False
                for i, box in enumerate(bboxes):
                    catid = box['category_id']
                    gt_class[i][0] = self.catid2clsid[catid]
                    gt_bbox[i, :] = box['clean_bbox']
                    is_crowd[i][0] = box['iscrowd']
                    # check RLE format 
                    if 'segmentation' in box and box['iscrowd'] == 1:
                        gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
                    elif 'segmentation' in box and box['segmentation']:
                        if not np.array(box['segmentation']
                                        ).size > 0 and not self.allow_empty:
                            bboxes.pop(i)
                            gt_poly.pop(i)
                            np.delete(is_crowd, i)
                            np.delete(gt_class, i)
                            np.delete(gt_bbox, i)
                        else:
                            gt_poly[i] = box['segmentation']
                        has_segmentation = True

                if has_segmentation and not any(
                        gt_poly) and not self.allow_empty:
                    continue

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    'gt_poly': gt_poly,
                }

                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

                # TODO: remove load_semantic
                if self.load_semantic and 'semantic' in self.data_fields:
                    seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
                                            'train2017', im_fname[:-3] + 'png')
                    coco_rec.update({'semantic': seg_path})

            logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                im_path, img_id, im_h, im_w))
            if is_empty:
                empty_records.append(coco_rec)
            else:
                records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records

其基类 DetDataset :

from paddle.io import Dataset

class DetDataset(Dataset):
    """
    Load detection dataset.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        use_default_label (bool): whether to load default label list.
        repeat (int): repeat times for dataset, use in benchmark.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 use_default_label=None,
                 repeat=1,
                 **kwargs):
        super(DetDataset, self).__init__()
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        self.data_fields = data_fields
        self.sample_num = sample_num
        self.use_default_label = use_default_label
        self.repeat = repeat
        self._epoch = 0
        self._curr_iter = 0

    def __len__(self, ):
        return len(self.roidbs) * self.repeat

    def __call__(self, *args, **kwargs):
        return self

    def __getitem__(self, idx):
        n = len(self.roidbs)
        if self.repeat > 1:
            idx %= n
        # data batch
        roidb = copy.deepcopy(self.roidbs[idx])
        if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
            roidb = [roidb, ] + [
                copy.deepcopy(self.roidbs[np.random.randint(n)])
                for _ in range(4)
            ]
        if isinstance(roidb, Sequence):
            for r in roidb:
                r['curr_iter'] = self._curr_iter
        else:
            roidb['curr_iter'] = self._curr_iter
        self._curr_iter += 1
        
        # roidb['num_classes'] = len(self.catid2clsid) # COCODataset 80 cls

        return self.transform(roidb)

    def check_or_download_dataset(self):
        self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
                                            self.image_dir)

    def set_kwargs(self, **kwargs):
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

    def set_transform(self, transform):
        self.transform = transform

    def set_epoch(self, epoch_id):
        self._epoch = epoch_id

    def parse_dataset(self, ):
        raise NotImplementedError(
            "Need to implement parse_dataset method of Dataset")

    def get_anno(self):
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)
### 回答1: PaddleDetection VisualDL 是一个可视化开发工具,用于在深度学习任务分析和可视化模型的训练过程和结果。它提供了丰富的可视化功能,可以帮助用户更好地理解、分析和优化模型。 使用 PaddleDetection VisualDL,用户可以在训练过程实时监测模型的训练效果。通过可视化工具,用户可以查看训练损失、准确率等指标的变化情况,并可以通过图表对比不同模型的训练情况。这可以帮助用户及时发现模型训练过程的问题和错误,以便进行调整和优化。 此外,PaddleDetection VisualDL 还提供了可视化的模型结构、权重分布图等功能。通过可视化工具,用户可以直观地了解模型的结构和参数分布情况,帮助用户更好地理解并优化模型。 PaddleDetection VisualDL 还支持高级的可视化功能,如可视化过程特征图、可视化数据增强等。通过可视化过程特征图,用户可以深入理解模型的工作原理,进一步优化模型的性能。同时,通过可视化数据增强,用户可以直观地查看不同数据增强策略对模型性能的影响,有助于选择合适的数据增强方法。 总之,PaddleDetection VisualDL 是一个强大的可视化开发工具,可以帮助用户更好地理解和优化深度学习模型。通过可视化的方式,用户可以实时监测训练过程,分析模型的训练效果,并深入了解模型的结构和参数分布情况,从而优化模型的性能。 ### 回答2: PaddleDetection 是一个基于PaddlePaddle深度学习框架开发的目标检测工具包。它提供了丰富的预训练模型和数据增强方法,可用于各种目标检测任务。 使用 PaddleDetection 进行目标检测,首先需要准备训练和测试数据集。可以使用 PaddleDetection 提供的数据集,也可以自己创建数据集。然后,选择合适的预训练模型,如YOLOv3、Faster R-CNN等,并根据数据集进行模型训练。在训练过程,可以根据需要设置训练参数和优化策略。 完成训练后,可以使用 PaddleDetection 进行目标检测的推理。通过加载已训练好的模型和测试图像,PaddleDetection 可以实时地检测输入图像的目标,并输出检测结果。可以根据需要调整模型的阈值和置信度来控制检测的准确性和召回率。 PaddleDetection 还提供了模型优化和部署的相关工具。可以使用模型优化工具对训练好的模型进行优化,以提高推理速度和降低模型大小。然后,使用模型部署工具将优化后的模型部署到目标设备上,以实现在嵌入式设备、服务器等不同环境进行目标检测。 总之,PaddleDetection 是一个功能强大的目标检测工具包,提供了训练、测试、推理、优化和部署等一系列的功能,可用于各种目标检测任务,并且具有易于使用的接口和灵活的扩展性。 ### 回答3: PaddleDetection是一个高效、灵活和可扩展的目标检测工具包,由百度公司开源。它基于PaddlePaddle深度学习框架,提供了丰富的目标检测模型和算法,具有较高的精度和速度。 使用PaddleDetection进行目标检测的步骤包括数据准备、模型选择、模型训练和模型预测。 首先,需要准备目标检测的数据集。PaddleDetection提供了数据集转换工具,可以将常见的目标检测数据集(如COCO、VOC)转换成PaddleDetection所需的格式。数据集应包含训练集和验证集,并根据目标类别进行标注。 接下来,根据实际应用需求选择适合的目标检测模型。PaddleDetection提供了多种预训练模型,如Faster R-CNN、YOLOv3等。可以选择合适的模型作为基础模型,然后根据数据集进行微调或训练新模型。 进行模型训练时,需要配置训练参数,如学习率、批大小等。PaddleDetection提供了训练脚本,可以根据需求进行修改和配置。在训练过程,可以使用分布式训练来加快训练速度,提高训练效果。 完成模型训练后,可以使用训练好的模型进行目标检测预测。PaddleDetection提供了预测脚本,可以加载训练好的模型,并对图像或视频进行目标检测。预测结果可以包括目标框、类别和得分等信息,可以根据需求进行后续处理和应用。 总而言之,PaddleDetection是一个强大的目标检测工具包,可以帮助用户快速构建和训练目标检测模型,并进行目标检测预测。使用PaddleDetection可以大大简化目标检测的开发流程,提高开发效率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值