训练准备工作(六)

def train_model(model,
                dataset,
                cfg,
                validate=False,
                test=dict(test_best=False, test_last=False),
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (:obj:`Dataset`): Train dataset.
        cfg (dict): The config dict for training.
        validate (bool): Whether to do evaluation. Default: False.
        test (dict): The testing option, with two keys: test_last & test_best.
            The value is True or False, indicating whether to test the
            corresponding checkpoint.
            Default: dict(test_best=False, test_last=False).
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(log_level=cfg.get('log_level', 'INFO'))

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        persistent_workers=cfg.data.get('persistent_workers', False),
        seed=cfg.seed)
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('train_dataloader', {}))

    data_loaders = [
        build_dataloader(ds, **dataloader_setting) for ds in dataset
    ]

data_loaders的数据类型是列表,[<torch.utils.data.dataloader.DataLoader object at 0x7f4b3bfc7b90>]

dataset 变量是列表,[<pyskl.datasets.dataset_wrappers.RepeatDataset object at 0x7fe7b3668550>]

既然 dataset 变量是一个包含 RepeatDataset 对象的列表,那么这段代码的实现就更加清晰了。

RepeatDataset 是一个数据集包装器,可以对原始数据集进行重复采样,以增加训练样本的数量。

在这种情况下,build_dataloader() 函数会为 dataset 列表中的每个 RepeatDataset 对象创建一个对应的 DataLoader 对象,并将它们存储在 data_loaders 列表中。

这种设计方式有以下好处:

  1. 可以同时处理多个数据集,每个数据集都有独立的 DataLoader
  2. 通过使用 RepeatDataset,可以对数据集进行重复采样,增加训练样本的数量。
  3. 将 DataLoader 对象存储在 data_loaders 列表中,方便后续的训练和评估操作。

总的来说,这段代码的实现逻辑是合理的,能够灵活地处理多个数据集,并对数据进行重复采样。

怎么发现data_loaders不应该是DataLoader类进行实例化的吗?比如train_dataloader = DataLoader(train_data, batch-size=16)这样?

事实上,在这段代码中,data_loaders 是正确地使用 DataLoader 类进行实例化的。

让我们回顾一下您提到的示例代码:

train_dataloader = DataLoader(train_data, batch_size=16)

这段代码是正确的,它创建了一个 DataLoader 对象,用于加载训练数据。

在前面的代码中,data_loaders 是一个列表,其中包含了多个 DataLoader 对象。这是因为原始的 dataset 变量可能包含多个数据集对象,因此需要为每个数据集创建一个 DataLoader

所以,data_loaders 是正确的,它是一个包含多个 DataLoader 对象的列表,而不是直接使用 DataLoader 类进行实例化。

相比之下,您提到的 train_dataloader = DataLoader(train_data, batch_size=16) 是一个单独的 DataLoader 对象,用于加载训练数据。这种方式也是正确的,只是与原始代码中的 data_loaders 实现有所不同。

总结来说,原始代码中的 data_loaders 是正确的,它是一个 DataLoader 对象的列表,用于处理可能包含多个数据集的情况。您提到的 train_dataloader = DataLoader(train_data, batch_size=16) 也是正确的,只是处理了单个训练数据集的情况。两种实现方式都是合理的,只是根据具体的需求而采用不同的方式。

pyskl/pyskl/apis/train.py at main · kennymckormick/pyskl · GitHub

  • 18
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值