Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型QWen-VL的一条示例数据可能如下所示:

{
  "input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?",
  "output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}

由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用IterableDataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
基于IterableDataset的数据加载,代码实现如下:

import torch
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                sample = process_line(line)
                yield sample

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # Train your model using the batch of data
    pass

在实际训练中还会遇到两个问题:

  1. 大模型一般需要使用多机多卡训练,需要避免多个进程中dataloader读取数据的竞争,并保证不同进程之间不会重复读取数据;
  2. 数据文件中某些行无法正确被解析,或者引用的外部资源找不到,导致process_line成员函数报错。数据集需要handle这类错误,防止因为报错中断训练。

以上问题对策如下:

  1. 在多机多卡的DDP训练中,可以使用DistributedSampler来处理多进程读数据的情形。DistributedSampler可以确保不同进程之间不会重复读取数据。具体的代码实现如下:
# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)

# Create a DistributedSampler
sampler = DistributedSampler(dataset)

# Create a DataLoader using the DistributedSampler
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for batch in dataloader:
    # Train your model using the batch of data
    pass
  1. 可以在调用process_line的时候试图handle一个错误,如果出错就跳过这条数据,改为(试图)获取下一条数据。具体的代码实现如下:
import torch
import logger
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                try:
                    sample = process_line(line)
                    yield sample
                except Exception as e:
                    # Print the detailed error information
                    logger.error(line)
                    logger.error(e)
                    pass

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

如果使用的是普通的Dataset,则参考以下代码,在__getitem__里面加入报错逻辑:

class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as file:
            for line in file:
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        try:
            sample = self.process_line(line)
            return sample
        except Exception as e:
            # Print the detailed error information
            logger.error(line)
            logger.error(e)
            return self.__getitem__((index+1) % self.__len__())

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample
  • 11
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值