@dc.dataclass
class DataConfig(object):
train_file: Optional[str] = None
val_file: Optional[str] = None
test_file: Optional[str] = None
num_proc: Optional[int] = None
@property
def data_format(self) -> str:
return Path(self.train_file).suffix
@property
def data_files(self) -> dict[NamedSplit, str]:
return {
split: data_file
for split, data_file in zip(
[Split.TRAIN, Split.VALIDATION, Split.TEST], # 可以写成["a", "b", "c"]只是一个标记的key罢了
[self.train_file, self.val_file, self.test_file],
)
if data_file is not None
}
def _load_datasets(
data_dir: str,
data_format: str,
data_files: dict[NamedSplit, str],
num_proc: Optional[int],
) -> DatasetDict:
if data_format == '.jsonl':
dataset_dct = load_dataset(
data_dir,
data_files=data_files,
split=None,
num_proc=num_proc,
)
else:
raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
return dataset_dct
class DataManager(object):
def __init__(self, data_dir: str, data_config: DataConfig):
self._num_proc = data_config.num_proc
self._dataset_dct = _load_datasets(
data_dir,
data_config.data_format,
data_config.data_files, ## 这边其实是一个字典
self._num_proc,
)
def _get_dataset(self, split) -> Optional[Dataset]:
return self._dataset_dct.get(split, None)
def get_dataset(
self,
split,
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
batched: bool = True,
remove_orig_columns: bool = True,
) -> Optional[Dataset]:
orig_dataset = self._get_dataset(split)
if orig_dataset is None:
return
if remove_orig_columns:
remove_columns = orig_dataset.column_names
else:
remove_columns = None
return orig_dataset.map(
process_fn,
batched=batched,
remove_columns=remove_columns,
num_proc=self._num_proc,
)
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
combine=ft_config.combine,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
再数据导入后,可以使用get来进行获得对于对应key的value,即对应的dataset,然后再使用map函数来进行后续的数据处理,不同的数据集,例如train,val用不同的 map中指定的函数来进行数据处理
data = load_dataset(
'json',
data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
data_files='chat.json',
split='train[:80%]' # 加载特定的比例的数据
)
data = load_dataset(
'json',
data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
data_files='chat.json',
split=None
)
data
当你使用 split=None
来加载数据集时,load_dataset
函数会尝试加载整个数据集而不进行任何分割。在某些情况下,即使你没有明确指定数据集的分割方式,load_dataset
可能会自动识别并加载数据集的默认分割。这取决于数据集的结构以及它的文件命名或目录结构。
data = load_dataset(
'json',
data_dir='/home/qyj/code/train_llava/data/LLaVA-CC3M-Pretrain-595K',
data_files='chat.json',
split=Split.TRAIN # 加载特定的比例的数据
)
data
区别就是唯一的就是要通过一个字典来获得数据集,一个不需要。