Data Preparation and Preprocessing

Data Preparation

def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
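
The last step of the loop normalizes every source dataset to the same schema (prompt, query, response, history) plus a prefix column. Below is a minimal, self-contained sketch of that step using a toy in-memory dataset; the column names instruction and output are made up for illustration and are not part of the repository.

from datasets import Dataset

# Toy dataset whose columns do not yet match the unified schema (illustrative names).
raw = Dataset.from_dict({
    "instruction": ["Translate to French: hello"],
    "output": ["bonjour"],
})

# Rename existing columns to the target names and fill missing ones with None,
# mirroring the rename_column / add_column calls above.
dataset = raw.rename_column("instruction", "prompt")
dataset = dataset.rename_column("output", "response")
for missing in ("query", "history"):
    dataset = dataset.add_column(missing, [None] * len(dataset))
dataset = dataset.add_column("prefix", [""] * len(dataset))

print(dataset.column_names)  # ['prompt', 'response', 'query', 'history', 'prefix']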

Data Preprocessing

def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1: # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset

1. Initialization

  • Get the dataset's column names.

  • Create a prompt template object with Template.

2. Helper Function

get_dialog(examples)
  • Processes the input examples and yields dialog data, supporting questions with a single answer or multiple answers (a sketch of the expected dialog structure follows).
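
The Template class is repository-specific, but judging from how the preprocessing functions below consume its output, get_dialog is expected to return a flat list of alternating prompt/response turns. A purely illustrative example (all strings and roles are made up):

# Hypothetical output of prompt_template.get_dialog(query, answer, history, prefix)
# for one history round plus the current query/answer:
dialog = [
    "Human: What is 2+2?\nAssistant: ", "4",   # history turn
    "Human: And 3+3?\nAssistant: ",     "6",   # current query and answer
]
# dialog[2*i] is the i-th prompt, dialog[2*i+1] is the i-th response,
# and dialog[-1] is the final answer used by the unsupervised and pairwise branches.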

3. Preprocessing Functions

preprocess_pretrain_dataset(examples)
  • For the pretraining stage (pt), builds groups of text segments that each begin with the bos (begin-of-sequence) token and omit the eos (end-of-sequence) token.

  • Concatenates all texts and splits them into chunks of the block size (a toy walk-through follows below).
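
A toy numeric walk-through of the grouping logic, with made-up token ids and a placeholder bos_token_id:

concatenated_ids = list(range(10))   # 10 tokens after concatenating all examples
block_size = 4 - 1                   # e.g. max_source_length = 4, minus 1 for <bos>
total_length = (len(concatenated_ids) // block_size) * block_size  # 9; the remainder is dropped
bos_token_id = 1                     # placeholder id

result = [[bos_token_id] + concatenated_ids[i: i + block_size]
          for i in range(0, total_length, block_size)]
print(result)  # [[1, 0, 1, 2], [1, 3, 4, 5], [1, 6, 7, 8]] -- token 9 is dropped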

preprocess_supervised_dataset(examples)
  • For the supervised fine-tuning stage (sft), builds inputs in the format <bos> X Y <eos> and labels in the format <ignore> ... <ignore> Y <eos>.

  • For inputs that contain history, builds multiple input-label pairs (a sketch of how one pair is assembled follows below).
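
A sketch of how a single (source, target) pair contributes to input_ids and labels; the token ids are made up, and IGNORE_INDEX is typically -100 so the loss skips prompt tokens:

IGNORE_INDEX = -100
eos_token_id = 2                # placeholder id
source_ids = [5, 6, 7]          # tokenized prompt X
target_ids = [8, 9]             # tokenized answer Y

input_ids = source_ids + target_ids + [eos_token_id]
labels = [IGNORE_INDEX] * len(source_ids) + target_ids + [eos_token_id]

print(input_ids)  # [5, 6, 7, 8, 9, 2]
print(labels)     # [-100, -100, -100, 8, 9, 2]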

preprocess_unsupervised_dataset(examples)
  • For the unsupervised stage, builds inputs in the format <bos> X and labels in the format <bos> Y.

preprocess_pairwise_dataset(examples)
  • For the reward model stage (rm), builds input pairs in the formats <bos> X Y1 <eos> and <bos> X Y2 <eos>.

4. Example-Printing Functions

  • print_supervised_dataset_example(example): prints a supervised dataset example.

  • print_pairwise_dataset_example(example): prints a pairwise (reward model) dataset example.

  • print_unsupervised_dataset_example(example): prints an unsupervised dataset example.

5. Main Flow

  • Selects the appropriate preprocessing function for the given stage.

  • Preprocesses the dataset with dataset.map, with support for batching, parallel workers, and caching (see the sketch after this list).

  • Prints an example from the preprocessed dataset.
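
A self-contained sketch of the batched map pattern used above, with a toy preprocessing function standing in for the stage-specific one:

from datasets import Dataset

dataset = Dataset.from_dict({"prompt": ["hello world", "foo bar baz"]})

def toy_preprocess(examples):
    # with batched=True, `examples` maps column names to lists of values
    return {"input_ids": [[len(w) for w in text.split()] for text in examples["prompt"]]}

dataset = dataset.map(
    toy_preprocess,
    batched=True,
    remove_columns=dataset.column_names,  # drop the raw text columns, as above
)
print(dataset[0])  # {'input_ids': [5, 5]}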
