mmdetection源码笔记(三):创建数据集模型之datasets/custom.py的解读(下)

引言

custom.py是datasets/coco.py中CocoDataset的父类,
主要有以下几个类:

  • load_annotations():在子类中重写了,作用是:加载标注文件中的annotation字典,返回图片信息,比如:info{"filename":"284193,faa9000f2678b5e.jpg"}。(存疑)
  • get_ann_info():在子类中被重写了,作用是:获得annotation的信息,其实是调用了_parse_ann_info();它的形参是指定的图片id,返回值是个字典:bboxes,bboxes_ignore, labels, masks, mask_polys, poly_lens.
  • _filter_imgs():在子类中被重写了,作用是:过滤图片,去除没有annotation标注文件的图片,以及图片尺寸小于min-size的图片。
  • _set_group_flag():根据宽高比,为图像分组0和1,保存在flag中,->数组类型
  • __getitem__():在类中定义了__getitem__()方法,那么他的实例对象(假设为P)就可以这样P[key]取值。当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。
  • prepare_train_img():指定的idx,是个数,非列表,单独对一张图片进行训练预处理,可以自己debug一下是什么数据。(推荐用print,博主就是一个一个print出来,看数据的变化,下面的方法也是一样)
  • prepare_test_img():指定的idx,是个数,非列表,单独对一张图片进行测试预处理。(文章结尾处,打印经过prepare_test_img()处理后的数据,及其格式)

代码中还出现了其他文件的几个类:
DC() 、ImageTransform()、BboxTransform()、MaskTransform()、SegMapTransform()、Numpy2Tensor()、ExtraAugmentation()、to_tensor()、random_scale()
这些类,有的我已经在下面标明作用,有的还没,后面会继续更新对代码的理解。(其实都无关紧要,就是对数据的处理加工而已,只需知道加工后的数据格式是什么样的就行咯)

custom.py

初始化部分+不重要的方法
import os.path as osp
import mmcv
import numpy as np
from mmcv.parallel import DataContainer as DC # A container for any type of objects 
from torch.utils.data import Dataset
from .registry import DATASETS
from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                         SegMapTransform, Numpy2Tensor)
from .utils import to_tensor, random_scale
from .extra_aug import ExtraAugmentation
@DATASETS.register_module
class CustomDataset(Dataset):
    CLASSES = None                           # 在子类中被覆盖
    def __init__(self,
                 ann_file,                   # 标注文件
                 img_prefix,                 # 图片路径
                 img_scale,                  # 图片尺寸
                 img_norm_cfg,               # 输入图像初始化,减去均值mean并处以方差std,to_rgb表示将bgr转为rgb 
                 multiscale_mode='value',    
                 size_divisor=None,          # 32,对图像进行resize时的最小单位,32表示所有的图像都会被resize成32的倍数
                 proposal_file=None,         # 候选框文件
                 num_max_proposals=1000,     # 候选框最大数量
                 flip_ratio=0,               # 图像的随机左右翻转的概率
                 with_mask=True,             # 训练(测试)时附带mask
                 with_crowd=True,            # 附带difficult的样本
                 with_label=True,            # 附带label
                 with_semantic_seg=False,    # 附带semantic_seg
                 seg_prefix=None,            # seg路径
                 seg_scale_factor=1,
                 extra_aug=None,             # 额外的增强措施 (有待发现)
                 resize_keep_ratio=True,
                 test_mode=False):           # 默认初始化为false,处于训练阶段
        # prefix of images path
        self.img_prefix = img_prefix       
        
        # 下面的load_annotations()、_filter_imgs()等在子类中定义的类,父类初始化的时候,调用的其实是子类的重写的函数。
        # load annotations (and proposals)
        self.img_infos = self.load_annotations(ann_file)        
        # 一个例子:比如info{'file_name': '273278,e118d000ec53d5cd.jpg', 'height': 1365, 'width': 2048, 'id': 4370, 'filename': '273278,e118d000ec53d5cd.jpg'} 
        if proposal_file is not None:        					 # 此处为None
            self.proposals = self.load_proposals(proposal_file)  # 不为None的话,那proposals是什么样数据类型???????传入的proposal又是什么样的数据类型???,有待解决
        else:
            self.proposals = None
            
        # filter images with no annotation during training
        # build数据集并初始化时,根据test_mode的值,来对数据进行处理。重新初始化 img_infos 和 proposals
        if not test_mode:                                        # test_mode 为 false 表示是处于train阶段咯,不是测试阶段
            valid_inds = self._filter_imgs()                     # valid_inds 获得有效图片的 ID
            self.img_infos = [self.img_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
        # (long_edge, short_edge) or [(long1, short1), (long2, short2), ...]
        self.img_scales = img_scale if isinstance(img_scale,
                                                  list) else [img_scale]
        assert mmcv.is_list_of(self.img_scales, tuple)           # 是否是元组类型(在配置文件中查看)
        # normalization configs
        self.img_norm_cfg = img_norm_cfg
        # multi-scale mode (only applicable for multi-scale training)
        self.multiscale_mode = multiscale_mode
        assert multiscale_mode in <
  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值