需求
每次需要从一个datase(有很多个不同的dataset)t中取batchsize个数据,可以从sampler或者dataloader中下手,可以为每个datase构造一个dataloader,然后对这些dataloader进行随机采样。
如今需要将所有 DataLoader 中的 batch 放到一个 DataLoader 中,对外封装成一个普通的 DataLoader,取数据的方式要保持以下这种常规套路:
for i, batch in enumerate(dataloader):
# do whatever u want
# ... ...
在 Python 的实现里,要做到这点就要求 dataloader 是一个可迭代的对象(Iterable),你可以实现一个** 迭代器(Iterator)** 或 生成器(generator),然后实例化一个对象作为 dataloader。
迭代器(Iterrator)
由于 generator 通常是函数形式,而基于类的实现方式会更贴近原生的 DataLoader。因此,就这点来看,选取迭代器的实现方式会比较合适。
在 Python 中,一个类要实现为迭代器,通常需要实现两个接口:iter() & next()
iter() 直接返回自身实例即可,而 next() 则需要返回每次迭代时被取出来的数据
另外,迭代器在迭代结束时,也就是所有数据都被取完时,需要抛出 StopIteration 异常作为结束标记
Code
i).每次迭代时从其中1个 DataLoader 中取1个 batch 返回
这种实现方式有个问题需要考虑:如何决定该从哪个 DataLoader 去取数据?
最直接、也最符合我们习惯的应该是随机选取,当然,更“老土”的就是顺序选取了。在顺序选取的情况下,若这些 DataLoader 永远按照固定顺序排列的话,就相当于让模型每个周期都先训练任务A,再训练任务B… 因此,我们可以打乱这些 DataLoader 的顺序,使得模型在下个周期可能先训练任务B,再训练任务A。
import random
from typing import Iterable
from bisect import bisect_right
from accelerate import Accelerator
from torch.utils.data.dataloader import DataLoader
class MultiDataLoader:
@staticmethod
def cumsum(seq: Iterable):
"""
计算累加和。这里输入的 seq 是多个 DataLoader,可以 list, tuple 等等。
然后将它们的 batch 数量进行累加,依次放入到一个 list 中返回。
"""
r, s = [], 0
for sub in seq:
l = len(sub)
r.append(s + l)
s += l
return r
def __new__(cls, dataloaders: Iterable[DataLoader], *args, **kwargs):
"""
之所以重写 __new__ 方法,是因为如果仅有1个 DataLoader 的话,
那就返回它本身的实例就好了,没有必要去使用我们这个自定义的 DataLoader。
"""
dataloader_list = list(dataloaders)
# If only one dataloader passed here,
# we keep using the original type, not modifying anything.
if len(dataloader_list) == 1:
# 注意,这种情况下,我们这个类的 __init__() 方法是不会被调用的,因为返回的不是我们这个类的实例。
return dataloader_list.pop()
# Otherwise, the truly ConcatDataloader type returns.
else:
return super().__new__(cls)
def __init__(
self, dataloaders: Iterable[DataLoader],
shuffle: bool = True, random_select: bool = True,
accelerator: Accelerator = None
) -> None:
dataloader_list = list(dataloaders)
if len(dataloader_list) > 1:
# Cuz accelerate only process Pytorch DataLoader, we should do this spliting operation here,
# instead of doing outside(i.e. in the main program flow)
if accelerator is not None:
# 因为 accelerate 仅对 Pytorch 的 DataLoader 做数据划分,所以我们需要在这个类的内部去调用 accelerate 的接口,
# 而不能在这个类已经实例化之后再调用,因为我们实现的这个类并非 Pytorch DataLoade 的子类。
dataloader_list = [accelerator.prepare_data_loader(dataloader) for dataloader in dataloader_list]
self.dataloaders = dataloader_list
# Whether to shuffle the order of batches yield from each dataloader.
# 是否打乱 DataLoader 的顺序。
# 如果不打乱,各个 DataLoader 将保持它们加入到 self.dataloader_list 时的顺序。
self.shuffle = shuffle
# If true, random select a dataloader to yield batch in each data iteration.
# 每次迭代取数据时,是否随机选取1个 DataLoader
# (否则就是顺序选取,也就是待前面的 DataLoader 的所有 batch 都取完后才轮到下一个)
self.random_select = random_select
self._init()
def _init(self):
# Cumulate sum of batch counts.
self.cum_num_batches = self.cumsum(self.dataloaders)
self._verify()
# 将各个 DataLoader 变成迭代器,因为我们取数据的方式是:next(dataloader)
self.dataloader_iters = [iter(dataloader) for dataloader in self.dataloaders]
if not self.random_select:
self.batch_idx = 0
def _verify(self):
if len(self.dataloaders) != len(self.cum_num_batches):
raise RuntimeError(f"num of dataloaders: {len(self.dataloaders)} != num of cum_num_batches: {len(self.cum_num_batches)}")
def __iter__(self):
return self
def __next__(self):
# 随机选取1个 DataLoader 取数据,若被选中的 DataLoader 中的数据已经取完,
# 那就再随机选另一个,直至所有 DataLoader 都被“掏空”。
if self.random_select:
# If the dataloader has been selected is already stopped,
# kick it out & select another one until there is no one can be picked.
while self.dataloader_iters:
self.dataloader_idx = random.choice(range(len(self.dataloader_iters)))
try:
return next(self.dataloader_iters[self.dataloader_idx])
except StopIteration:
self.dataloader_iters.pop(self.dataloader_idx)
# 非随机选取 DataLoader,那么就根据当前 batch 的索引来决定应该从哪个 DataLoader 中去取数据。
else:
if self.batch_idx < len(self):
# Decide which dataloader should play now
self.dataloader_idx = bisect_right(self.cum_num_batches, self.batch_idx)
self.batch_idx += 1
return next(self.dataloader_iters[self.dataloader_idx])
# 状态重置,以便之后可以再次迭代取数据。
self._reset()
# 迭代器在迭代终止时要抛出这个异常
raise StopIteration
def __len__(self):
# batch 数量就是所有 DataLoader 的 batch 总数
return self.cum_num_batches[-1]
def _reset(self):
if self.shuffle:
random.shuffle(self.dataloaders)
self._init()
@property
def batch_size(self):
# We assume that the batch size of each dataloader equals,
# so directly use the batch size of the first one dataloader.
return self.dataloaders[0].batch_size
ii). 每次迭代时从所有 DataLoader 中分别取1个 batch 拼接起来返回
这种实现方式下,有个关键点需要意识到:由于每次迭代都要从所有 DataLoader 中取数据,因此这就要求它们的“长度”(即 batch 数量)是一致的!于是,我们需要先确定哪个 DataLoader 是最长的,然后将其余 DataLoader pad 到等长。也因此,这么一搞,batch 数量就等于最长的那个 DataLoader 的长度。
还有一点可以考虑下,就是我们应该不希望同一种任务的样本都排列在一起,而是希望所有任务的样本在这个拼接起来的 batch 中是随机混在一起的,以致于不会造成“前面10个样本都是任务A的、接下来10个都是任务B的”这种情况。
也正是基于以上的这些考虑,CW 给这个 DataLoader 命名为 ‘MixupDataLoader’
class MixupDataLoader:
"""
A dataloader of Iterable type mixing up a batch yield from multiple dataloaders,
'mixup' means concat multiple batches and shuffle the data inside them.
Note:
The final batch size is sum of batch size of each dataloader.
"""
def __new__(cls, dataloaders: Iterable[DataLoader], *args, **kwargs):
dataloader_list = list(dataloaders)
# If only one dataloader, we will keep the original dataloader type
if len(dataloader_list) == 1:
return dataloader_list.pop()
# Otherwise, the truely 'MixupDataLoader' instance returns
else:
return super().__new__(cls)
def __init__(self, dataloaders: Iterable[DataLoader], accelerator: Accelerator = None) -> None:
dataloader_list = list(dataloaders)
if len(dataloader_list) > 1:
# The final batch size is sum of batch size of each dataloader.
self.batch_size = sum(dataloader.batch_size for dataloader in dataloader_list)
# Cuz accelerate only process Pytorch DataLoader, we should do this spliting operation here,
# instead of doing outside(i.e. in the main program flow)
if accelerator is not None:
dataloader_list = [accelerator.prepare_data_loader(dataloader) for dataloader in dataloader_list]
self._dataloaders = dataloader_list
# 确定哪个 DataLoader 最长
max_dataloader_len = 0
for i, dataloader in enumerate(dataloader_list):
if len(dataloader) > max_dataloader_len:
max_idx = i
max_dataloader_len = len(dataloader)
self._main_dataloader_idx = max_idx
self._num_batches = max_dataloader_len
# Beside the dataloader which has the most batches,
# other ones will be set as an cycling iterator,
# thus making them align with the dataloader that has the most batches.
# itertools.cycle() 会将可迭代对象设置为无限循环,然后利用 zip() 就可以将其余 DataLoader 与最长的那个对齐,
# (zip() 的性质就是将多个可迭代对象与“最短”的那个对齐)
self.data_chain = iter(
zip(*[
cycle(dataloader) if i != max_idx else dataloader
for i, dataloader in enumerate(dataloader_list)
])
)
def __len__(self):
return self._num_batches
def __iter__(self):
return self
def __next__(self):
try:
# A tuple consist of batches yield from each dataloader.
# 这里取出来的是由各个 DataLoader 的 batch 组成的元组
batch_tuple = next(self.data_chain)
device = batch_tuple[0]['input_ids'].device
# 由于我这里的每个 batch 是 dict,key 是 str, value 是 tensor,
# 而每种任务的 key 不一定是一样的,因此需要先将所有的 key 都统计出来。
data_keys = set(batch_tuple[0].keys())
for batch in batch_tuple:
data_keys.update(batch.keys())
# Padding batch for missing keys, cuz data keys of each dataloader may not the same.
# 如果任务 B 的 batch 没有任务 A 中的 batch 中的某个 key,那就对应“补齐”。
for k in data_keys:
for batch in batch_tuple:
if k not in batch:
if k == 'attention_mask':
batch[k] = torch.ones_like(batch['input_ids'])
elif 'id' in k:
batch[k] = torch.full(
(batch['input_ids'].size(0),), -1,
device=device, dtype=torch.long
)
# Leave for future
else:
pass
# We compute the sample counts in-time instead of using sum of batch size of each dataloader here
# incase of the 'batch padding' case.
num_samples = sum(batch['input_ids'].size(0) for batch in batch_tuple)
# Useful for mixing up & shuffling data inside each batch yield from each dataloader
# 生成随机索引以实现打乱(shuffle)的功能
random_indices = torch.randperm(num_samples, device=device)
# Concat the data generate from each dataloder
# 将不同任务的 batch 对应 key 下的 tensor 都“收集”到一起,构造出1个大 batch,这就是最终返回的形式。
mixup_batch = {}
for k in data_keys:
mixup_data = torch.empty(
(num_samples,) + batch_tuple[0][k].shape[1:],
device=device, dtype=batch_tuple[0][k].dtype
)
mixup_data[random_indices] = torch.cat(tuple(batch[k] for batch in batch_tuple))
mixup_batch[k] = mixup_data
return mixup_batch
except StopIteration:
# 状态重置,以便下个周期能够继续迭代取数据。
self._reset()
raise StopIteration
def _reset(self):
self.data_chain = iter(
zip(*[
cycle(dataloader) if i != self._main_dataloader_idx else dataloader
for i, dataloader in enumerate(self._dataloaders)
])
)
参考:CW