Detectron2 Source Code Reading: the Trainer (3)

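An interesting detail: DefaultTrainer inherits from TrainerBase, not from SimpleTrainer. In this post we focus on how the trainer is written, and on multi-GPU training.

One more thing I want to complain about: detectron2 is configured through cfg from start to finish, yet nobody tells us how to write a config.
Tip: the config options are documented in config/defaults.py.

First, DefaultTrainer. Its methods fall into the following groups:

  1. build methods: build_hooks, build_writers, build_model, build_optimizer, build_lr_scheduler, build_train_loader, build_test_loader, build_evaluator
  2. training logic: resume_or_load, train, run_step
  3. other: state_dict, load_state_dict
  4. auto_scale_workers

Let's start with the __init__ method: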

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

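1. Set up the logger

2. Call auto_scale_workers to automatically adjust the batch size, learning rate and related parameters

3. Build the model, optimizer and data_loader

4. For distributed training, wrap the model automatically (create_ddp_model)

Let's look at auto_scale_workers first, with this example config: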

IMS_PER_BATCH: 16
BASE_LR: 0.1
REFERENCE_WORLD_SIZE: 8
MAX_ITER: 5000
STEPS: (4000,)
CHECKPOINT_PERIOD: 1000

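IMS_PER_BATCH is the batch size of each step, summed over all GPUs: with IMS_PER_BATCH set to 16 and 8 GPUs, each GPU gets 2 images.
REFERENCE_WORLD_SIZE is the number of GPUs the values in this config were tuned for. It defaults to 0, which turns auto-scaling off.

The function takes a config written for REFERENCE_WORLD_SIZE GPUs and rescales its parameters to the number of GPUs actually in use.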

scale = num_workers / old_world_size
bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))

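For example, if the config was written for 8 GPUs and we now train on 16, then scale = 2: the batch size and the learning rate are doubled, while max_iter and the warmup iterations are halved.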
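
The arithmetic can be checked with a small standalone sketch. It does not import detectron2: the function name scale_solver and the plain dict standing in for cfg.SOLVER are illustrative assumptions, the WARMUP_ITERS value is made up for the example, and only the four fields shown in the excerpt are rescaled.

```python
def scale_solver(solver, num_workers):
    """Rescale a SOLVER-like dict following the excerpt (hypothetical helper)."""
    old_world_size = solver["REFERENCE_WORLD_SIZE"]
    # 0 means auto-scaling is off; an unchanged worker count needs no rescaling.
    if old_world_size == 0 or old_world_size == num_workers:
        return dict(solver)
    scale = num_workers / old_world_size
    scaled = dict(solver)
    scaled["IMS_PER_BATCH"] = int(round(solver["IMS_PER_BATCH"] * scale))
    scaled["BASE_LR"] = solver["BASE_LR"] * scale
    scaled["MAX_ITER"] = int(round(solver["MAX_ITER"] / scale))
    scaled["WARMUP_ITERS"] = int(round(solver["WARMUP_ITERS"] / scale))
    return scaled


reference = {
    "IMS_PER_BATCH": 16,
    "BASE_LR": 0.1,
    "REFERENCE_WORLD_SIZE": 8,
    "MAX_ITER": 5000,
    "WARMUP_ITERS": 1000,  # illustrative value, not from the example config
}
scaled = scale_solver(reference, num_workers=16)
print(scaled["IMS_PER_BATCH"], scaled["BASE_LR"], scaled["MAX_ITER"], scaled["WARMUP_ITERS"])
```

Going from 8 to 16 workers this gives a batch size of 32, a learning rate of 0.2, 2500 iterations and 500 warmup iterations, so the number of images seen over the whole run stays the same.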

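5. Build _trainer

The underlying trainer is either SimpleTrainer or AMPTrainer, depending on cfg.SOLVER.AMP.ENABLED

6. Build the lr_scheduler

The optimizer has to be passed in here

7. Build the checkpointer

8. Set start_iter, max_iter and cfg

9. build_hooks and register_hooks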

Parallelism

One mechanism that is hard to follow is detectron2's parallelization, which starts from launch:


def launch(
    main_func,
    # Should be num_processes_per_machine, but kept for compatibility.
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):

For single-machine multi-GPU training I usually set dist_url to "auto". The body of launch:

    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes

        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)
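
The arguments passed to mp.start_processes are:

  • _distributed_worker: the function executed in each child process.
  • nprocs=num_gpus_per_machine: the number of child processes to start, normally one per GPU on this machine.
  • args: a tuple of arguments passed to _distributed_worker.
    • main_func: the function holding the main training logic.
    • world_size: the total number of processes in the distributed job.
    • num_gpus_per_machine: the number of GPUs on each machine.
    • machine_rank: the rank of this machine.
    • dist_url: the URL used to set up inter-process communication.
    • args: the arguments for main_func.
    • timeout: the timeout.
  • daemon=False: the child processes are not daemonic. Whether the parent waits for them is controlled by join (True by default in start_processes), not by daemon. A daemonic child would be terminated when the parent exits and would not be allowed to create child processes of its own, such as DataLoader workers.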
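
Going back to the dist_url == "auto" branch of launch: it asks for a free port and builds a tcp://127.0.0.1:<port> URL from it. How _find_free_port is implemented is not shown in this post; the sketch below is one common way to do it with the standard library, and the helper name find_free_port is my own.

```python
import socket


def find_free_port():
    """Ask the OS for an unused TCP port by binding to port 0 (illustrative helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 lets the OS pick any free port
        return sock.getsockname()[1]


port = find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
print(dist_url)
```

The socket is closed before the port is used, so another process could in principle grab the same port in between. That is an inherent limit of this approach.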

mp here is torch.multiprocessing, and start_processes looks like this:

def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass

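The target actually handed to each Process is _wrap(fn, i, args, error_queue), where fn is the target function and i is the process index. _wrap calls fn(i, *args), which is why _distributed_worker takes local_rank as its first parameter, ahead of the args tuple built in launch: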

def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):

local_rank is the rank of this process on the local machine, i.e. the process index i.

    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert num_gpus_per_machine <= torch.cuda.device_count()
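
  • Check whether a GPU is available in the current environment.
  • If so, make sure num_gpus_per_machine does not exceed the number of GPUs actually present on the machine.
  • Compute the global rank from the machine rank and the local rank:
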
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Setup the local process group.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    main_func(*args)
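
To see how the ranks come out, here is a sketch that replays the calling convention without starting any processes. torch.multiprocessing invokes the target as fn(i, *args), with the process index first. The functions wrap and worker are simplified stand-ins I wrote for this example, not the library code, and the 2 machines x 4 GPUs layout is an assumed setup.

```python
def wrap(fn, i, args):
    # Same calling convention as torch.multiprocessing: index first, then args.
    return fn(i, *args)


def worker(local_rank, world_size, num_gpus_per_machine, machine_rank):
    # Same formula as in _distributed_worker.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    assert 0 <= global_rank < world_size
    return global_rank


num_machines, num_gpus_per_machine = 2, 4
world_size = num_machines * num_gpus_per_machine

ranks = {
    (machine_rank, i): wrap(worker, i, (world_size, num_gpus_per_machine, machine_rank))
    for machine_rank in range(num_machines)
    for i in range(num_gpus_per_machine)
}
for (machine_rank, local_rank), global_rank in ranks.items():
    print(f"machine {machine_rank}, local rank {local_rank} -> global rank {global_rank}")
```

Every process ends up with a distinct global rank in 0..7, while the local rank (0..3) repeats on each machine and is what torch.cuda.set_device receives.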

One question I ran into when training: with multiple GPUs, how is the data_loader handled?
Let's take build_detection_train_loader as an example.


@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
    **kwargs
):
    """
    Build a dataloader for object detection with some default features.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all workers.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size / num_workers``, where ``mapped_element`` is produced
            by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **kwargs
    )

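As we can see, the dataset argument of build_detection_train_loader accepts two kinds of input, a list or a torch.utils.data.Dataset. The logic of the function is:

  1. If dataset is a list, wrap it with DatasetFromList to turn it into a torch.utils.data.Dataset;
  2. If a mapper is given, wrap the dataset with MapDataset so that every sample goes through the mapper;
  3. Check the sampler: it must be None for an iterable dataset, and defaults to TrainingSampler for a map-style one;
  4. Call build_batch_data_loader and return the data_loader.

So build_detection_train_loader is itself more of a wrapper; the real work is done by build_batch_data_loader.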
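
The first two wrapping steps can be pictured with a dependency-free sketch. ListDataset and MappedDataset are hypothetical stand-ins for DatasetFromList and MapDataset; the real classes do more than this, and the mapper here only adds a made-up "area" key.

```python
class ListDataset:
    """Hypothetical stand-in for DatasetFromList: a list exposed as a map-style dataset."""

    def __init__(self, items):
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


class MappedDataset:
    """Hypothetical stand-in for MapDataset: applies the mapper lazily, on access."""

    def __init__(self, dataset, mapper):
        self._dataset = dataset
        self._mapper = mapper

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        return self._mapper(self._dataset[idx])


def build_dataset(dataset, mapper=None):
    # Steps 1 and 2 of the list above: wrap a list, then wrap with the mapper.
    if isinstance(dataset, list):
        dataset = ListDataset(dataset)
    if mapper is not None:
        dataset = MappedDataset(dataset, mapper)
    return dataset


dicts = [
    {"file_name": "a.jpg", "width": 640, "height": 480},
    {"file_name": "b.jpg", "width": 480, "height": 640},
]
dataset = build_dataset(dicts, mapper=lambda d: {**d, "area": d["width"] * d["height"]})
print(len(dataset), dataset[1]["file_name"], dataset[1]["area"])
```

Nothing is mapped until an index is requested, which is what lets a sampler decide later which samples each process actually loads.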


def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
    drop_last: bool = True,
    single_gpu_batch_size=None,
    **kwargs,
):
    """
    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
    1. support aspect ratio grouping options
    2. use no "batch collation", because this is common for detection training

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.
        single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
            `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
            `total_batch_size` allows you to specify the total aggregate batch size across gpus.
            It is an error to supply a value for both.
        drop_last (bool): if ``True``, the dataloader will drop incomplete batches.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError(
                """total_batch_size and single_gpu_batch_size are mutually incompatible.
                Please specify only one. """
            )
        batch_size = single_gpu_batch_size
    else:
        world_size = get_world_size()
        assert (
            total_batch_size > 0 and total_batch_size % world_size == 0
        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, world_size
        )
        batch_size = total_batch_size // world_size
    logger = logging.getLogger(__name__)
    logger.info("Making batched data loader with batch_size=%d", batch_size)



    if isinstance(dataset, torchdata.IterableDataset):
        print("dataset type is torchdata.IterableDataset")
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        print("dataset type is not torchdata.IterableDataset")
        dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)

    if aspect_ratio_grouping:
        assert drop_last, "Aspect ratio grouping will drop incomplete batches."
        data_loader = torchdata.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
            **kwargs
        )  # yield individual mapped dict
        data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
        if collate_fn is None:
            return data_loader
        return MapDataset(data_loader, collate_fn)
    else:
        return torchdata.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            num_workers=num_workers,
            collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
            worker_init_fn=worker_init_reset_seed,
            **kwargs
        )

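First, three arguments are required: dataset, sampler and total_batch_size.

  1. Work out batch_size: either use single_gpu_batch_size directly, or divide total_batch_size by the world size (the total number of processes across all machines) to get the batch size for each GPU
  2. For a map-style dataset, call dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size) to wrap it once more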
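
Step 1 can be isolated into a small function. This is a sketch of the same branching with world_size passed in explicitly instead of read from get_world_size(); the name per_gpu_batch_size is my own, and it raises ValueError where the original uses an assert.

```python
def per_gpu_batch_size(total_batch_size, world_size, single_gpu_batch_size=None):
    """Return the batch size each process should load (illustrative helper)."""
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError("specify only one of total_batch_size and single_gpu_batch_size")
        return single_gpu_batch_size
    if not total_batch_size or total_batch_size <= 0 or total_batch_size % world_size != 0:
        raise ValueError(
            f"Total batch size ({total_batch_size}) must be divisible "
            f"by the number of gpus ({world_size})."
        )
    return total_batch_size // world_size


print(per_gpu_batch_size(16, world_size=8))
print(per_gpu_batch_size(None, world_size=8, single_gpu_batch_size=4))
```

With IMS_PER_BATCH set to 16 on 8 GPUs each process loads 2 images per step, while a total of 16 on 3 GPUs is rejected because it does not divide evenly.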

class ToIterableDataset(data.IterableDataset):
    """
    Convert an old indices-based (also called map-style) dataset
    to an iterable-style dataset.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        sampler: Sampler,
        shard_sampler: bool = True,
        shard_chunk_size: int = 1,
    ):
        """
        Args:
            dataset: an old-style dataset with ``__getitem__``
            sampler: a cheap iterable that produces indices to be applied on ``dataset``.
            shard_sampler: whether to shard the sampler based on the current pytorch data loader
                worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
                workers, it is responsible for sharding its data based on worker id so that workers
                don't produce identical data.

                Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
                and this argument should be set to True. But certain samplers may be already
                sharded, in that case this argument should be set to False.
            shard_chunk_size: when sharding the sampler, each worker will
        """
        assert not isinstance(dataset, data.IterableDataset), dataset
        assert isinstance(sampler, Sampler), sampler
        self.dataset = dataset
        self.sampler = sampler
        self.shard_sampler = shard_sampler
        self.shard_chunk_size = shard_chunk_size

    def __iter__(self):
        if not self.shard_sampler:
            sampler = self.sampler
        else:
            # With map-style dataset, `DataLoader(dataset, sampler)` runs the
            # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))`
            # will run sampler in every of the N worker. So we should only keep 1/N of the ids on
            # each worker. The assumption is that sampler is cheap to iterate so it's fine to
            # discard ids in workers.
            sampler = _shard_iterator_dataloader_worker(self.sampler, self.shard_chunk_size)
        for idx in sampler:
            yield self.dataset[idx]

    def __len__(self):
        return len(self.sampler)

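shard_sampler and shard_chunk_size control how the sampler behaves when the DataLoader runs several worker processes. This is a different level from splitting data across GPUs, which is left to the sampler itself.

  1. shard_sampler

    • shard_sampler is a bool that says whether the sampler should be sharded by the worker id of the current PyTorch DataLoader.
    • When shard_sampler is True, each worker keeps only part of the indices produced by the sampler, so the workers do not load the same samples.
    • This is needed because DataLoader(ToIterableDataset(dataset, sampler)) runs the sampler in every one of the N workers; without sharding, all of them would yield identical data.
  2. shard_chunk_size

    • shard_chunk_size is an int: the number of consecutive indices handed to one worker before moving on to the next.
    • build_batch_data_loader passes batch_size here, so each worker receives runs of batch_size consecutive indices.
    • It only changes the granularity of the split; every worker still ends up with roughly 1/N of the indices.

The sharding itself is done by _shard_iterator_dataloader_worker:
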
def _shard_iterator_dataloader_worker(iterable, chunk_size=1):
    # Shard the iterable if we're currently inside pytorch dataloader worker.
    worker_info = data.get_worker_info()
    if worker_info is None or worker_info.num_workers == 1:
        # do nothing
        yield from iterable
    else:
        # worker0: 0, 1, ..., chunk_size-1, num_workers*chunk_size, num_workers*chunk_size+1, ...
        # worker1: chunk_size, chunk_size+1, ...
        # worker2: 2*chunk_size, 2*chunk_size+1, ...
        # ...
        yield from _roundrobin(
            *[
                itertools.islice(
                    iterable,
                    worker_info.id * chunk_size + chunk_i,
                    None,
                    worker_info.num_workers * chunk_size,
                )
                for chunk_i in range(chunk_size)
            ]
        )
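
To see which indices each worker keeps, here is a sketch that takes worker_id and num_workers as explicit arguments instead of reading them from get_worker_info(). roundrobin is a simplified stand-in for detectron2's _roundrobin, and shard_for_worker is a name I made up. As in the original, the input must be re-iterable (a range, a list, a sampler), because every islice walks it from the start.

```python
import itertools


def roundrobin(*iterables):
    """Yield one item from each iterable in turn until all are exhausted."""
    iterators = [iter(it) for it in iterables]
    while iterators:
        alive = []
        for it in iterators:
            try:
                item = next(it)
            except StopIteration:
                continue  # drop exhausted iterators
            alive.append(it)
            yield item
        iterators = alive


def shard_for_worker(iterable, worker_id, num_workers, chunk_size=1):
    """Return the items that one DataLoader worker would keep (illustrative helper)."""
    if num_workers == 1:
        return list(iterable)
    slices = [
        itertools.islice(
            iterable,
            worker_id * chunk_size + chunk_i,  # first index of this slice
            None,
            num_workers * chunk_size,  # stride over all workers' chunks
        )
        for chunk_i in range(chunk_size)
    ]
    return list(roundrobin(*slices))


indices = range(12)
for worker_id in range(2):
    print(worker_id, shard_for_worker(indices, worker_id, num_workers=2, chunk_size=2))
# 0 [0, 1, 4, 5, 8, 9]
# 1 [2, 3, 6, 7, 10, 11]
```

With chunk_size=2 each worker takes two consecutive indices at a time, and together the two workers cover every index exactly once.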