maskrcnn_benchmark Source Code Walkthrough: dataloader (Part 1)

Since the Stitcher reference code mainly modifies the dataloader, we cover the dataloader first.

Reference code: https://github.com/yukang2017/Stitcher/tree/master/maskrcnn_benchmark

This post touches a fair amount of PyTorch source code and runs long, so it is split into two parts (this is Part 1).


As before, we start from train_net.py:

#train_net.py
'''
from maskrcnn_benchmark.config import cfg
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
'''
    model = train(cfg, args.local_rank, args.distributed)

def train(cfg, local_rank, distributed):
    '''
    ... (omitted) ...
    '''
    # from maskrcnn_benchmark.data import make_data_loader
    data_loader = make_data_loader(  # look here
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        distributed,
    )

    return model

from maskrcnn_benchmark.data import make_data_loader is the import we care about; as before, start with __init__.py in the data directory.

#__init__.py
from .build import make_data_loader
#build.py
def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0, batch_stitch=False):
    num_gpus = get_world_size()
    if is_train:  # training mode
        images_per_batch = cfg.SOLVER.IMS_PER_BATCH  # batch size
        if batch_stitch:  # batch-level stitching: halve each image's height and width and concatenate along the batch dimension, so 4x the batch size fits in the same memory
            # [n,c,h,w] -> [4n,c,h/2,w/2]
            images_per_batch *= 4
        assert (  # batch size must be divisible by the number of GPUs
            images_per_batch % num_gpus == 0
        ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format(
            images_per_batch, num_gpus)
        images_per_gpu = images_per_batch // num_gpus
        shuffle = True  # note this shuffle
        num_iters = cfg.SOLVER.MAX_ITER  # maximum number of iterations
    else:
        images_per_batch = cfg.TEST.IMS_PER_BATCH
        assert (
            images_per_batch % num_gpus == 0
        ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format(
            images_per_batch, num_gpus)
        images_per_gpu = images_per_batch // num_gpus
        shuffle = False if not is_distributed else True
        num_iters = None
        start_iter = 0

    if images_per_gpu > 1:
        logger = logging.getLogger(__name__)
        logger.warning(
            "When using more than one image per GPU you may encounter "
            "an out-of-memory (OOM) error if your GPU does not have "
            "sufficient memory. If this happens, you can reduce "
            "SOLVER.IMS_PER_BATCH (for training) or "
            "TEST.IMS_PER_BATCH (for inference). For training, you must "
            "also adjust the learning rate and schedule length according "
            "to the linear scaling rule. See for example: "
            "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
        )

    # group images which have similar aspect ratio. In this case, we only
    # group in two cases: those with width / height > 1, and the other way around,
    # but the code supports more general grouping strategy
    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []  # whether to group images by aspect ratio: only images whose width/height ratio falls on the same side of 1 share a batch. Images in a batch must have equal sizes (padded up to the batch's max height and width), so mixing very different aspect ratios wastes memory on padding

    paths_catalog = import_file(  # import the catalog config that records where each dataset/model lives
        "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
    )
    DatasetCatalog = paths_catalog.DatasetCatalog
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

    # If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
    
    transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train, batch_stitch=batch_stitch)  # build_transforms when training; at test time with bbox aug enabled, transforms is set to None and applied later
    datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)  # line 169
    ### we stop reading make_data_loader here for now
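A quick sanity check on the 4x claim above. This is a minimal standalone sketch (made-up shapes, not Stitcher's actual stitching code) showing that halving height and width quarters each image tensor, so four times as many images occupy the same memory:

import torch
import torch.nn.functional as F

n, c, h, w = 2, 3, 800, 1344        # hypothetical full-size batch shape
full = torch.randn(n, c, h, w)      # one regular batch: [n, c, h, w]

# four batches' worth of images, each downscaled to half resolution,
# concatenated along the batch dimension: [4n, c, h/2, w/2]
halves = [F.interpolate(torch.randn(n, c, h, w), scale_factor=0.5,
                        mode="bilinear", align_corners=False)
          for _ in range(4)]
stitched = torch.cat(halves, dim=0)

print(stitched.shape)                     # torch.Size([8, 3, 400, 672])
assert stitched.numel() == full.numel()   # identical element count, identical memory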

The last two lines, build_transforms and build_dataset, are what we dig into next.
(1) from .transforms import build_transforms
Look under the data/transforms folder:

#__init__.py
from .transforms import Compose
from .transforms import Resize
from .transforms import RandomHorizontalFlip
from .transforms import ToTensor
from .transforms import Normalize

from .build import build_transforms

#build.py
from . import transforms as T
# this imports the very package build.py lives in; per the __init__.py above, transforms exposes all the objects listed there

def build_transforms(cfg, is_train=True, batch_stitch=False):
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        if batch_stitch:  # with batch stitching each image is shrunk to 1/2 (though both stitching variants shrink images to 1/2)
            min_size = tuple(s // 2 for s in list(min_size))  # min_size, e.g. (800,)
            max_size //= 2  # max_size, e.g. 1333
        flip_horizontal_prob = 0.5  # cfg.INPUT.FLIP_PROB_TRAIN
        flip_vertical_prob = cfg.INPUT.VERTICAL_FLIP_PROB_TRAIN
        brightness = cfg.INPUT.BRIGHTNESS
        contrast = cfg.INPUT.CONTRAST
        saturation = cfg.INPUT.SATURATION
        hue = cfg.INPUT.HUE
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        flip_horizontal_prob = 0.0
        flip_vertical_prob = 0.0
        brightness = 0.0
        contrast = 0.0
        saturation = 0.0
        hue = 0.0

    to_bgr255 = cfg.INPUT.TO_BGR255  # convert to Caffe's image convention: BGR channel order, 0~255 range
    normalize_transform = T.Normalize(
        mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255
    )
    color_jitter = T.ColorJitter(
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
    )

    transform = T.Compose(
        [
            color_jitter,
            T.Resize(min_size, max_size),  # note this Resize
            T.RandomHorizontalFlip(flip_horizontal_prob),
            T.RandomVerticalFlip(flip_vertical_prob),
            T.ToTensor(),
            normalize_transform,
        ]
    )
    return transform
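Note that this T.Compose is maskrcnn_benchmark's own, not torchvision's: each transform here must take and return the (image, target) pair together so that geometric ops keep the annotations in sync with the pixels. A simplified sketch of what it looks like (not the verbatim source):

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        # thread (image, target) through every transform in order, so that
        # resize/flip can adjust the boxes whenever they touch the image
        for t in self.transforms:
            image, target = t(image, target)
        return image, target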

build_transforms adds the various transforms to the dataloader according to the boolean flags in cfg; the one worth a closer look is Resize:

class Resize(object):
    def __init__(self, min_size, max_size):
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size

    # modified from torchvision to add support for max size
    def get_size(self, image_size):  # keep the aspect ratio while fitting the image within the min/max size limits
        ***
        ***
        return (oh, ow)

    def __call__(self, image, target=None):
        size = self.get_size(image.size)
        image = F.resize(image, size)  # once the target size is computed from min_size, max_size and the original size, just resize with interpolation
        if target is None:
            return image
        target = target.resize(image.size)  # the target (a BoxList) is resized to the new image size too, keeping boxes/masks aligned
        return image, target
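The elided get_size body is worth spelling out. Roughly reconstructed from the torchvision-derived logic in maskrcnn_benchmark's transforms.py (treat it as a sketch; details may differ from the exact source): pick one of the min_size candidates as the target for the shorter edge, cap it if the longer edge would exceed max_size, then scale both edges by the same ratio:

import random

def get_size(self, image_size):
    w, h = image_size
    size = random.choice(self.min_size)  # target length of the shorter edge
    max_size = self.max_size
    if max_size is not None:
        min_original = float(min(w, h))
        max_original = float(max(w, h))
        # if scaling the short edge to `size` would push the long edge
        # beyond max_size, shrink the target accordingly
        if max_original / min_original * size > max_size:
            size = int(round(max_size * min_original / max_original))
    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)                    # already the right size
    if w < h:
        ow = size
        oh = int(size * h / w)           # scale the long edge by the same ratio
    else:
        oh = size
        ow = int(size * w / h)
    return (oh, ow)                      # (height, width), as F.resize expects

With batch_stitch, min_size (800,) has already been halved to (400,) and max_size 1333 to 666, so this same logic naturally produces half-resolution images.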

So build_transforms returns a T.Compose object wrapping the various image-processing ops, and with batch_stitch every image is shrunk to 1/2 right away.
Without batch_stitch the images may still get shrunk later on.
(2) datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
dataset_list is a list of strings holding the names of the datasets used for training.
build_dataset simply returns a dataset object, merging the datasets in the list into one and applying the various transforms.
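For reference, here is a hedged sketch of build_dataset (simplified from maskrcnn_benchmark/data/build.py; the real function passes a couple of extra factory-specific arguments): each name is looked up in DatasetCatalog to get a factory class and its constructor args, the shared transforms are injected, and for training everything is merged into one dataset:

from maskrcnn_benchmark.data import datasets as D

def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
    datasets = []
    for dataset_name in dataset_list:
        data = dataset_catalog.get(dataset_name)  # {"factory": ..., "args": ...}
        factory = getattr(D, data["factory"])     # e.g. D.COCODataset
        args = data["args"]
        args["transforms"] = transforms           # every dataset shares the transforms
        datasets.append(factory(**args))

    if not is_train:
        return datasets                           # testing: keep datasets separate

    # training: concatenate everything into a single dataset
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)
    return [dataset]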
