Data Preparation and Preprocessing
Data Preparation
```python
def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]:  # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else:  # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
```
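The `checksum` helper simply compares a file's SHA-1 digest against the value recorded in `dataset_info.json`. A minimal, self-contained sketch of computing such a digest for a local file (the path in the usage comment is hypothetical):

```python
import hashlib

def sha1_of_file(file_path: str) -> str:
    # Read the file as bytes and return its SHA-1 hex digest,
    # i.e. the value prepare_data() compares against dataset_attr.dataset_sha1.
    with open(file_path, "rb") as f:
        return hashlib.sha1(f.read()).hexdigest()

# Hypothetical usage: print the digest so it can be pasted into dataset_info.json.
# print(sha1_of_file("data/my_dataset.json"))
```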
Data Preprocessing
```python
def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1:  # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1:  # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1:  # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

    if stage == "pt":
        print_unsupervised_dataset_example(dataset[0])
    elif stage == "sft":
        print_supervised_dataset_example(dataset[0])
    elif stage == "rm":
        print_pairwise_dataset_example(dataset[0])
    elif stage == "ppo":
        print_unsupervised_dataset_example(dataset[0])

    return dataset
```
1. Initialization

- Get the dataset's column names.
- Create a prompt template object with `Template`.
2. Helper function

`get_dialog(examples)`

- Iterates over the input examples and yields one dialog per valid example; supports questions with a single answer or with multiple answers. The expected dialog layout is sketched below.
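The repository's `Template.get_dialog` is not shown here; based on how its result is consumed below (even indices are encoded as sources, odd indices as targets, and the last element is the final answer), a toy stand-in producing the same flat layout might look like this:

```python
# A toy illustration (not the repository's Template class) of the dialog layout
# the downstream preprocessors assume: a flat list alternating
# [source_0, target_0, source_1, target_1, ...], ending with the final answer.
def toy_get_dialog(prefix, history, query, answer):
    dialog = []
    for old_q, old_a in history:           # earlier turns first
        dialog += [old_q, old_a]
    dialog += [prefix + query, answer]     # current turn last
    return dialog

example = toy_get_dialog("", [("Hi", "Hello!")], "What is 2+2?", "4")
# ['Hi', 'Hello!', 'What is 2+2?', '4']
```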
3. Data preprocessing functions

`preprocess_pretrain_dataset(examples)`

- For the pre-training stage (`pt`), builds grouped texts of the form `<bos> X1 X2 X3 ...`: each block starts with a `<bos>` (begin-of-sequence) token, and no `<eos>` (end-of-sequence) token is appended.
- Concatenates all texts and splits them into fixed-size blocks, as illustrated below.
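A toy illustration of the concatenate-and-chunk logic, using made-up token ids and a tiny block size instead of a real tokenizer:

```python
from itertools import chain

# Toy version of preprocess_pretrain_dataset's grouping: concatenate all
# tokenized texts, drop the remainder, and cut fixed-size blocks, each
# prefixed with a <bos> id. All ids below are made up.
bos_id = 1
text_ids = [[11, 12, 13], [14, 15], [16, 17, 18, 19]]    # per-example token ids
concatenated = list(chain(*text_ids))                     # [11, ..., 19]
block_size = 4 - 1                                        # max_source_length - 1
total = (len(concatenated) // block_size) * block_size    # drop the small remainder
blocks = [[bos_id] + concatenated[i: i + block_size]
          for i in range(0, total, block_size)]
# blocks == [[1, 11, 12, 13], [1, 14, 15, 16], [1, 17, 18, 19]]
```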
`preprocess_supervised_dataset(examples)`

- For the supervised fine-tuning stage (`sft`), builds inputs of the form `<bos> X Y <eos>` and labels of the form `<ignore> ... <ignore> Y <eos>`, so the loss is only computed on the answer tokens (see the masking sketch below).
- For inputs with history, builds multiple input-label pairs from the same conversation.
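A toy illustration of the label masking, assuming `IGNORE_INDEX` is -100 (the value ignored by the cross-entropy loss in transformers) and using made-up token ids:

```python
# Toy version of the masking in preprocess_supervised_dataset.
IGNORE_INDEX = -100              # assumed value; positions with -100 are ignored by the loss
eos_id = 2                       # made-up <eos> id
source_ids = [21, 22, 23]        # tokenized prompt (X)
target_ids = [31, 32]            # tokenized answer (Y)

input_ids = source_ids + target_ids + [eos_id]
labels = [IGNORE_INDEX] * len(source_ids) + target_ids + [eos_id]
# input_ids == [21, 22, 23, 31, 32, 2]
# labels    == [-100, -100, -100, 31, 32, 2]  -> loss only on Y and <eos>
```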
`preprocess_unsupervised_dataset(examples)`

- For the unsupervised stages, builds inputs of the form `<bos> X` and labels of the form `<bos> Y`.
`preprocess_pairwise_dataset(examples)`

- For the reward model stage (`rm`), builds input pairs of the form `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`, where `Y1` is the accepted answer and `Y2` the rejected one (a toy sketch follows below).
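A toy illustration of how the accepted and rejected sequences share the same prompt tokens (made-up ids, no real tokenizer):

```python
# Toy version of preprocess_pairwise_dataset: both candidates are prefixed
# with the same prompt tokens and end with <eos>. All ids are made up.
eos_id = 2
source_ids = [21, 22]            # tokenized prompt (X)
chosen_ids = [31, 32]            # tokenized preferred answer (Y1)
rejected_ids = [41]              # tokenized rejected answer (Y2)

accept_ids = source_ids + chosen_ids + [eos_id]    # [21, 22, 31, 32, 2]
reject_ids = source_ids + rejected_ids + [eos_id]  # [21, 22, 41, 2]
```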
4. Print example functions

- `print_supervised_dataset_example(example)`: prints a supervised dataset example.
- `print_pairwise_dataset_example(example)`: prints a pairwise dataset example.
- `print_unsupervised_dataset_example(example)`: prints an unsupervised dataset example.
5. Main flow

- Selects the appropriate preprocessing function according to the stage (`pt`, `sft`, `rm`, or `ppo`).
- Preprocesses the dataset with `dataset.map`, which supports batching, parallel workers, and caching (a minimal usage sketch follows below).
- Prints one example from the preprocessed dataset.
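A minimal, self-contained sketch of the batched `datasets.map` pattern used in the main flow, with a toy in-memory dataset and a stand-in preprocess function rather than the repository's real tokenization path:

```python
from datasets import Dataset

# Toy dataset with the same two text columns the walkthrough relies on.
dataset = Dataset.from_dict({"prompt": ["hello", "world"], "response": ["hi", "earth"]})

def toy_preprocess(examples):
    # Batched mode: each field is a list covering the whole batch.
    return {"text": [p + " -> " + r for p, r in zip(examples["prompt"], examples["response"])]}

dataset = dataset.map(
    toy_preprocess,
    batched=True,                           # pass whole batches to the function
    remove_columns=["prompt", "response"],  # keep only the newly built columns
    desc="Toy preprocessing"
)
print(dataset[0])  # {'text': 'hello -> hi'}
```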