load_dataset能一次性导入train, val, test

@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST], # 可以写成["a", "b", "c"]只是一个标记的key罢了
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }

def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format == '.jsonl':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files, ## 这边其实是一个字典
            self._num_proc,
        )

    def _get_dataset(self, split) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )

train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        combine=ft_config.combine,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)

再数据导入后,可以使用get来进行获得对于对应key的value,即对应的dataset,然后再使用map函数来进行后续的数据处理,不同的数据集,例如train,val用不同的 map中指定的函数来进行数据处理

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split='train[:80%]' # 加载特定的比例的数据
)

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split=None 
)
data

 

当你使用 split=None 来加载数据集时,load_dataset 函数会尝试加载整个数据集而不进行任何分割。在某些情况下,即使你没有明确指定数据集的分割方式,load_dataset 可能会自动识别并加载数据集的默认分割。这取决于数据集的结构以及它的文件命名或目录结构。

data = load_dataset(
    'json',
    data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
    data_files='chat.json',
    split=Split.TRAIN # 加载特定的比例的数据
)
data

区别就是唯一的就是要通过一个字典来获得数据集,一个不需要。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值