多Dataloader随机采样

需求

每次需要从一个datase(有很多个不同的dataset)t中取batchsize个数据,可以从sampler或者dataloader中下手,可以为每个datase构造一个dataloader,然后对这些dataloader进行随机采样。

如今需要将所有 DataLoader 中的 batch 放到一个 DataLoader 中,对外封装成一个普通的 DataLoader,取数据的方式要保持以下这种常规套路:

for i, batch in enumerate(dataloader):
    # do whatever u want
    # ... ...

在 Python 的实现里,要做到这点就要求 dataloader 是一个可迭代的对象(Iterable),你可以实现一个** 迭代器(Iterator)** 或 生成器(generator),然后实例化一个对象作为 dataloader。

迭代器(Iterrator)

由于 generator 通常是函数形式,而基于类的实现方式会更贴近原生的 DataLoader。因此,就这点来看,选取迭代器的实现方式会比较合适。

在 Python 中,一个类要实现为迭代器,通常需要实现两个接口:iter() & next()

iter() 直接返回自身实例即可,而 next() 则需要返回每次迭代时被取出来的数据

另外,迭代器在迭代结束时,也就是所有数据都被取完时,需要抛出 StopIteration 异常作为结束标记

Code

i).每次迭代时从其中1个 DataLoader 中取1个 batch 返回

这种实现方式有个问题需要考虑:如何决定该从哪个 DataLoader 去取数据?

最直接、也最符合我们习惯的应该是随机选取,当然,更“老土”的就是顺序选取了。在顺序选取的情况下,若这些 DataLoader 永远按照固定顺序排列的话,就相当于让模型每个周期都先训练任务A,再训练任务B… 因此,我们可以打乱这些 DataLoader 的顺序,使得模型在下个周期可能先训练任务B,再训练任务A。

import random

from typing import Iterable
from bisect import bisect_right

from accelerate import Accelerator
from torch.utils.data.dataloader import DataLoader


class MultiDataLoader:
    @staticmethod
    def cumsum(seq: Iterable):
        """
        计算累加和。这里输入的 seq 是多个 DataLoader,可以 list, tuple 等等。
        然后将它们的 batch 数量进行累加,依次放入到一个 list 中返回。
        """
        r, s = [], 0
        for sub in seq:
            l = len(sub)
            r.append(s + l)
            s += l

        return r
    
    def __new__(cls, dataloaders: Iterable[DataLoader], *args, **kwargs):
        """
        之所以重写 __new__ 方法,是因为如果仅有1个 DataLoader 的话,
        那就返回它本身的实例就好了,没有必要去使用我们这个自定义的 DataLoader。
        """
        dataloader_list = list(dataloaders)
        # If only one dataloader passed here,
        # we keep using the original type, not modifying anything.
        if len(dataloader_list) == 1:
            # 注意,这种情况下,我们这个类的 __init__() 方法是不会被调用的,因为返回的不是我们这个类的实例。
            return dataloader_list.pop()
        # Otherwise, the truly ConcatDataloader type returns.
        else:
            return super().__new__(cls)
    
    def __init__(
        self, dataloaders: Iterable[DataLoader],
        shuffle: bool = True, random_select: bool = True,
        accelerator: Accelerator = None
    ) -> None:
        dataloader_list = list(dataloaders)
        if len(dataloader_list) > 1:
            # Cuz accelerate only process Pytorch DataLoader, we should do this spliting operation here,
            # instead of doing outside(i.e. in the main program flow)
            if accelerator is not None:
                # 因为 accelerate 仅对 Pytorch 的 DataLoader 做数据划分,所以我们需要在这个类的内部去调用 accelerate 的接口,
                # 而不能在这个类已经实例化之后再调用,因为我们实现的这个类并非 Pytorch DataLoade 的子类。
                dataloader_list = [accelerator.prepare_data_loader(dataloader) for dataloader in dataloader_list]
            self.dataloaders = dataloader_list
    
            # Whether to shuffle the order of batches yield from each dataloader.
            # 是否打乱 DataLoader 的顺序。
            # 如果不打乱,各个 DataLoader 将保持它们加入到 self.dataloader_list 时的顺序。
            self.shuffle = shuffle
            # If true, random select a dataloader to yield batch in each data iteration. 
            # 每次迭代取数据时,是否随机选取1个 DataLoader
            # (否则就是顺序选取,也就是待前面的 DataLoader 的所有 batch 都取完后才轮到下一个)
            self.random_select = random_select
    
            self._init()
    
    def _init(self):
        # Cumulate sum of batch counts.
        self.cum_num_batches = self.cumsum(self.dataloaders)
        self._verify()
    
        # 将各个 DataLoader 变成迭代器,因为我们取数据的方式是:next(dataloader)
        self.dataloader_iters = [iter(dataloader) for dataloader in self.dataloaders]
        if not self.random_select:
            self.batch_idx = 0
    
    def _verify(self):
        if len(self.dataloaders) != len(self.cum_num_batches):
            raise RuntimeError(f"num of dataloaders: {len(self.dataloaders)} != num of cum_num_batches: {len(self.cum_num_batches)}")
    
    def __iter__(self):
        return self
    
    def __next__(self):
        # 随机选取1个 DataLoader 取数据,若被选中的 DataLoader 中的数据已经取完,
        # 那就再随机选另一个,直至所有 DataLoader 都被“掏空”。
        if self.random_select:
            # If the dataloader has been selected is already stopped,
            # kick it out & select another one until there is no one can be picked.
            while self.dataloader_iters:
                self.dataloader_idx = random.choice(range(len(self.dataloader_iters)))
                try:
                    return next(self.dataloader_iters[self.dataloader_idx])
                except StopIteration:
                    self.dataloader_iters.pop(self.dataloader_idx)
        # 非随机选取 DataLoader,那么就根据当前 batch 的索引来决定应该从哪个 DataLoader 中去取数据。
        else:
            if self.batch_idx < len(self):
                # Decide which dataloader should play now
                self.dataloader_idx = bisect_right(self.cum_num_batches, self.batch_idx)
                self.batch_idx += 1
                return next(self.dataloader_iters[self.dataloader_idx])
    
        # 状态重置,以便之后可以再次迭代取数据。
        self._reset()
        # 迭代器在迭代终止时要抛出这个异常
        raise StopIteration
    
    def __len__(self):
        # batch 数量就是所有 DataLoader 的 batch 总数
        return self.cum_num_batches[-1]
    
    def _reset(self):
        if self.shuffle:
            random.shuffle(self.dataloaders)
    
        self._init()
    
    @property
    def batch_size(self):
        # We assume that the batch size of each dataloader equals,
        # so directly use the batch size of the first one dataloader.
        return self.dataloaders[0].batch_size
ii). 每次迭代时从所有 DataLoader 中分别取1个 batch 拼接起来返回

这种实现方式下,有个关键点需要意识到:由于每次迭代都要从所有 DataLoader 中取数据,因此这就要求它们的“长度”(即 batch 数量)是一致的!于是,我们需要先确定哪个 DataLoader 是最长的,然后将其余 DataLoader pad 到等长。也因此,这么一搞,batch 数量就等于最长的那个 DataLoader 的长度。

还有一点可以考虑下,就是我们应该不希望同一种任务的样本都排列在一起,而是希望所有任务的样本在这个拼接起来的 batch 中是随机混在一起的,以致于不会造成“前面10个样本都是任务A的、接下来10个都是任务B的”这种情况。

也正是基于以上的这些考虑,CW 给这个 DataLoader 命名为 ‘MixupDataLoader’

class MixupDataLoader:
    """
    A dataloader of Iterable type mixing up a batch yield from multiple dataloaders,
    'mixup' means concat multiple batches and shuffle the data inside them.
    Note: 
        The final batch size is sum of batch size of each dataloader.
    """

    def __new__(cls, dataloaders: Iterable[DataLoader], *args, **kwargs):
        dataloader_list = list(dataloaders)
        # If only one dataloader, we will keep the original dataloader type
        if len(dataloader_list) == 1:
            return dataloader_list.pop()
        # Otherwise, the truely 'MixupDataLoader' instance returns
        else:
            return super().__new__(cls)
    
    def __init__(self, dataloaders: Iterable[DataLoader], accelerator: Accelerator = None) -> None:
        dataloader_list = list(dataloaders)
        if len(dataloader_list) > 1:
            # The final batch size is sum of batch size of each dataloader.
            self.batch_size = sum(dataloader.batch_size for dataloader in dataloader_list)

            # Cuz accelerate only process Pytorch DataLoader, we should do this spliting operation here,
            # instead of doing outside(i.e. in the main program flow)
            if accelerator is not None:
                dataloader_list = [accelerator.prepare_data_loader(dataloader) for dataloader in dataloader_list]            
            self._dataloaders = dataloader_list

            # 确定哪个 DataLoader 最长
            max_dataloader_len = 0
            for i, dataloader in enumerate(dataloader_list):
                if len(dataloader) > max_dataloader_len:
                    max_idx = i
                    max_dataloader_len = len(dataloader)
            self._main_dataloader_idx = max_idx
            self._num_batches = max_dataloader_len

            # Beside the dataloader which has the most batches,
            # other ones will be set as an cycling iterator,
            # thus making them align with the dataloader that has the most batches. 
            # itertools.cycle() 会将可迭代对象设置为无限循环,然后利用 zip() 就可以将其余 DataLoader 与最长的那个对齐,
            # (zip() 的性质就是将多个可迭代对象与“最短”的那个对齐)
            self.data_chain = iter(
                zip(*[
                    cycle(dataloader) if i != max_idx else dataloader
                    for i, dataloader in enumerate(dataloader_list)
                ])
            )

    def __len__(self):
        return self._num_batches

    def __iter__(self):
        return self

    def __next__(self):
        try:
            # A tuple consist of batches yield from each dataloader.
            # 这里取出来的是由各个 DataLoader 的 batch 组成的元组
            batch_tuple = next(self.data_chain)
            device = batch_tuple[0]['input_ids'].device
            
            # 由于我这里的每个 batch 是 dict,key 是 str, value 是 tensor,
            # 而每种任务的 key 不一定是一样的,因此需要先将所有的 key 都统计出来。
            data_keys = set(batch_tuple[0].keys())
            for batch in batch_tuple:
                data_keys.update(batch.keys())

            # Padding batch for missing keys, cuz data keys of each dataloader may not the same.
            # 如果任务 B 的 batch 没有任务 A 中的 batch 中的某个 key,那就对应“补齐”。
            for k in data_keys:
                for batch in batch_tuple:
                    if k not in batch:
                        if k == 'attention_mask':
                            batch[k] = torch.ones_like(batch['input_ids'])
                        elif 'id' in k:
                            batch[k] = torch.full(
                                (batch['input_ids'].size(0),), -1,
                                device=device, dtype=torch.long
                            )
                        # Leave for future
                        else:
                            pass

            # We compute the sample counts in-time instead of using sum of batch size of each dataloader here
            # incase of the 'batch padding' case.
            num_samples = sum(batch['input_ids'].size(0) for batch in batch_tuple)
            # Useful for mixing up & shuffling data inside each batch yield from each dataloader
            # 生成随机索引以实现打乱(shuffle)的功能
            random_indices = torch.randperm(num_samples, device=device)

            # Concat the data generate from each dataloder
            # 将不同任务的 batch 对应 key 下的 tensor 都“收集”到一起,构造出1个大 batch,这就是最终返回的形式。
            mixup_batch = {}
            for k in data_keys:
                mixup_data = torch.empty(
                    (num_samples,) + batch_tuple[0][k].shape[1:],
                    device=device, dtype=batch_tuple[0][k].dtype
                )
                mixup_data[random_indices] = torch.cat(tuple(batch[k] for batch in batch_tuple))
                mixup_batch[k] = mixup_data
            
            return mixup_batch
        except StopIteration:
            # 状态重置,以便下个周期能够继续迭代取数据。
            self._reset()
            raise StopIteration

    def _reset(self):
        self.data_chain = iter(
            zip(*[
                cycle(dataloader) if i != self._main_dataloader_idx else dataloader
                for i, dataloader in enumerate(self._dataloaders)
            ])
        )

参考:CW

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值