对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型QWen-VL的一条示例数据可能如下所示:
{
"input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?",
"output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}
由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用IterableDataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
注意:
- 当前的数据处理场景设定如下:数据源为单个巨大的 jsonl 文件,可能包含多达上百万行数据记录。且单条数据条目普遍较为简短。本文聚焦于利用 PyTorch 的 Dataset 类的默认接口。在实践中,当然也可以重写 Dataset 类的 init 方法,使其仅加载诸如数据索引、数据条目总量这类元数据,而将具体数据内容的获取环节后置至 getitem 方法被调用时执行,来规避一次性将大规模数据集完整载入内存,而无需使用 iterable dataset。然而本文场景下,由于数据总量大,仅加载数据索引的元数据,内存占用还是很大,没有显著改善资源利用情况。与此同时,在该业务场景下,诸如 len 这类用以表征数据规模等属性的元信息并不重要。此时选用 iterable dataset 会更方便。
- 本文核心目的在于清晰阐释相关接口的高效运用方式,方便实操使用。而到了实际解决问题的时候思路可以很开阔。
基于IterableDataset的数据加载,代码实现如下:
import torch
from torch.utils.data import IterableDataset
class MyIterableDataset(IterableDataset):
def __init__(self, data_file):
self.data_file = data_file
def __iter__(self):
return iter(self._load_data())