项目实训6-以PT部分为例分析前置读取代码和train_pt的部分实现

预训练(Pre-Trainin) 是指在大规模无标注数据集上对模型进行初步训练,使模型学习到广泛的语言特征和结构。

功能

  • 捕捉基础特征:预训练阶段模型通过大规模文本数据学习语言的基本特征,如词汇、语法和上下文关系。
  • 迁移学习的基础:预训练后的模型可以用于多种下游任务,通过微调(Fine-Tuning)进一步优化特定任务的性能。
  • 减少标注数据需求:通过预训练,模型已经掌握了广泛的语言知识,微调时只需要较少的标注数据即可达到较好的效果。

初步代码实现:

# coding=utf-8
# Implements several parameter-efficient pre-training method.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py


import math

from utils import (
    DynamicDataCollatorWithPadding,
    PeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
    dataset = prepare_data(model_args, data_args)

    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")

        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

前置读取分析:

由于后续的代码都需要从执行主要函数的前5行的类似代码

model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
dataset = prepare_data(model_args, data_args)

model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

因此在此处先行进行解释,后续的部分不再解释这一部分的代码:

prepare_args

首先通过 utils.commom中的prepare_args方法来读取模型、数据、训练、微调的一些参数

def prepare_args(
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    # 使用 HfArgumentParser 解析参数,可以从 JSON 文件或命令行获取参数。
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    # 检查各种参数的合法性,确保参数设置符合要求。
    data_args.init_for_training()

    assert stage == "sft" or (not training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not (training_args.do_train and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True while training."

    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
        "Quantization is only compatible with the LoRA method."

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
        else:
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
                "Quantized model only accepts a single checkpoint."

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enable fp16 mixed precision training.")

    if data_args.prompt_template == "default":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    # 根据优化器设置和量化位数调整训练参数。
    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args

prepare_args 函数根据训练阶段准备和检查训练参数,并设置日志记录和警告。

该函数解析命令行或 JSON 文件中的参数,确保参数的合法性,并进行必要的配置和建议。

最终返回包含模型参数、数据训练参数、序列到序列训练参数和微调参数的元组。

prepare_data

随后通过utils.common 中的 prepare_data方法来读取基本的数据

def prepare_data(
        model_args: ModelArguments,
        data_args: DataTrainingArguments
) -> Dataset:

    # 定义一个 checksum 辅助函数,用于校验文件的 SHA-1 哈希值是否匹配。如果不匹配,记录一个警告日志。
    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    # 定义一个字典 ext2type,将文件扩展名映射到相应的数据类型。
    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    # 获取 max_samples 参数,用于限制加载的数据样本数量。
	# 定义一个列表 all_datasets,用于存储加载的所有数据集。
    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    # 遍历 data_args.dataset_list 中的数据集属性,依次加载每个数据集。
	# 记录加载数据集的日志信息。
    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))
		
        '''
        根据 dataset_attr.load_from 的值,决定如何加载数据集:
        如果数据集来自 hf_hub,直接使用数据集名称。
        如果数据集来自 script,构建数据集路径。
        如果数据集来自文件系统,根据文件类型和数量设置 data_path 和 data_files。
        如果需要,进行文件的 SHA-1 校验。
        '''
        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        '''
        使用 load_dataset 函数加载数据集。
        根据 data_args.split 获取指定的分割数据集。
        如果设置了 max_samples,限制加载的数据样本数量。
        为数据集添加列,确保所有数据集具有一致的列名。
        将处理后的数据集添加到 all_datasets 列表中。
        '''
        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    '''
    如果只有一个数据集,直接返回该数据集。
	如果有多个数据集,将它们合并后返回。
    '''
    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
load_pretrained

函数用于加载预训练的模型和分词器,并根据不同的训练或推理阶段进行相应的配置。它支持模型量化、自动类注册和适配器初始化,并在需要时处理奖励模型和近端策略优化。函数最终返回配置好的模型和分词器。

def load_pretrained(
        model_args: ModelArguments,
        finetuning_args: FinetuningArguments,
        is_trainable: Optional[bool] = False,
        stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.
    
        stage:当前阶段,可以是 "pt"(预训练)、"sft"(监督微调)、"rm"(奖励模型)或 "ppo"(近端策略优化)。
        检查 is_trainable 和 checkpoint_dir 参数,并根据情况设置 finetuning_args。
        验证阶段是否合适,确保 RM 和 PPO 阶段只能与 LoRA 方法一起使用。
        
    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    # 设置加载配置的参数。从预训练模型路径加载分词器,并设置填充标记 ID。
    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
        tokenizer.pad_token_id = 0 # set as the <unk> token

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    # 加载模型配置。根据量化位数设置相应的量化配置,使用 bitsandbytes 库进行 4 位或 8 位量化。设置 device_map,确保模型在正确的设备上加载。
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    # 如果不是可训练模式,设置 device_map 为 "auto" 以自动分配设备。根据是否有检查点目录选择要加载的模型路径。加载预训练模型,设置数据类型和低 CPU 内存使用。
    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    # Register auto class to save the custom code files.
    # 注册自动类以保存自定义代码文件。初始化模型适配器,根据是否为可训练模式进行相应配置
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
        model.__class__.register_for_auto_class()

    # Initialize adapters
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    # 如果阶段是 RM 或 PPO,添加值头以评估奖励。加载奖励模型权重并进行相应配置。
    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    # 如果不是可训练模式,固定所有模型参数,并根据需要将模型数据类型转换为 fp16。打印可训练参数信息。返回加载的模型和分词器。
    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

    print_trainable_params(model)

    return model, tokenizer
preprocess_data

preprocess_data 函数根据不同的训练阶段对数据集进行预处理。它定义了多个预处理函数,处理预训练数据、监督数据、无监督数据和成对数据。根据不同的阶段选择合适的预处理函数,并使用 dataset.map 对数据集进行批处理预处理。预处理完成后,打印示例并返回处理好的数据集。

def preprocess_data(
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        data_args: DataTrainingArguments,
        training_args: Seq2SeqTrainingArguments,
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers 定义一个 get_dialog 函数,用于根据示例生成对话内容。
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    # 定义一个 preprocess_pretrain_dataset 函数,用于预训练数据集的预处理。将文本拼接成指定长度的块,并生成输入和标签。
    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # 定义一个 preprocess_supervised_dataset 函数,用于监督数据集的预处理。生成格式为 <bos> X Y <eos> 的输入和 <ignore> ... <ignore> Y <eos> 的标签。
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1: # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        # 定义一个 preprocess_unsupervised_dataset 函数,用于无监督数据集的预处理。生成格式为 <bos> X 的输入和 <bos> Y 的标签。
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        # 定义一个 preprocess_pairwise_dataset 函数,用于成对数据集的预处理。生成格式为 <bos> X Y1 <eos> 和 <bos> X Y2 <eos> 的输入对。
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    # 使用 dataset.map 函数对数据集进行预处理。根据阶段选择并打印预处理后的示例。返回预处理后的数据集。
    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset
DynamicDataCollatorWithPadding

DynamicDataCollatorWithPadding 类通过继承 DataCollatorWithPadding,增加了动态填充和生成注意力掩码的功能,特别适用于批处理数据的填充和序列处理。它能够处理包含 input_idslabels 的数据,确保在训练和评估时使用左侧填充,并生成相应的注意力掩码来标识填充部分。

综合考虑,这部分的实现如下功能:

准备预训练模型和数据集参数(从命令行或 JSON 文件中获取参数、配置日志记录、检查参数的合法性、根据不同的阶段设置相应的训练参数)

准备数据集(根据数据参数中的数据集列表,加载指定的数据集。支持从 Hugging Face Hub、脚本或本地文件系统加载数据集。检查并验证数据集文件的完整性。加载数据集,并根据需要限制样本数量。统一数据集的列名,确保数据集具有一致的结构。)

加载预训练模型和分词器(设置加载配置的参数。从预训练模型路径加载分词器,并设置填充标记 ID。根据量化位数设置相应的量化配置,使用 bitsandbytes 库进行 4 位或 8 位量化。根据是否有检查点目录选择要加载的模型路径。加载预训练模型,设置数据类型和低 CPU 内存使用。注册自动类以保存自定义代码文件。初始化模型适配器,根据是否为可训练模式进行相应配置。处理 RM 和 PPO 阶段时,添加值头以评估奖励。如果不是可训练模式,固定所有模型参数,并根据需要将模型数据类型转换为 fp16。)

预处理数据集(根据不同的训练阶段定义不同的预处理函数:预训练数据集预处理函数(preprocess_pretrain_dataset)。监督数据集预处理函数(preprocess_supervised_dataset)。无监督数据集预处理函数(preprocess_unsupervised_dataset)。成对数据集预处理函数(preprocess_pairwise_dataset)。选择合适的预处理函数,并使用 dataset.map 对数据集进行批处理预处理。打印预处理后的数据集示例(用于调试和验证)。返回预处理后的数据集。)

准备数据填充器(继承自 DataCollatorWithPadding,增加了动态填充和生成注意力掩码的功能。根据是否忽略填充标记计算损失,设置标签填充值。在批处理数据时,为序列生成注意力掩码,并填充到批次中最长的序列长度。)

后置训练部分

后续的训练过程首先是一个简单的利用 python.dataset.train_test_split方法来进行的交叉验证分割。

    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

随后便是定义模型,训练模型,保存模型,评估模型的经典步骤,其中定义模型该阶段采用的我们在 peft_trainer 中实现的集成于python.transformers中的Seq2SeqTrainer 模型。

class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    继承自Seq2SeqTrainer,实现了自定义模型保存和加载功能
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoint.

        This function will only be executed at the process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        保存训练参数作为模型检查点。该方法只在主进程中执行。
        创建输出目录并确保其存在。
        通过 unwrap_model 获取基础模型。如果模型有 pretrained_model 属性(如在使用 LoRA 时),则获取它并保存 v_head 的状态字典。
        根据微调类型(lora 或其他)保存模型:
        如果是 LoRA 微调,调用 save_pretrained 方法保存基础模型的参数。
        如果是全参数或冻结微调,临时设置 use_cache 属性为 True,保存模型后再将其恢复为 False。
        如果有 tokenizer,则将其保存到输出目录。
        将训练参数和微调参数分别保存到 TRAINING_ARGS_NAME 和 FINETUNING_ARGS_NAME 文件中。
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
            backbone_model = getattr(model, "pretrained_model")
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
        else:
            backbone_model = model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
        else: # freeze/full tuning
            backbone_model.config.use_cache = True
            backbone_model.save_pretrained(
                output_dir,
                state_dict=get_state_dict(backbone_model),
                safe_serialization=self.args.save_safetensors
            )
            backbone_model.config.use_cache = False
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
            f.write(self.args.to_json_string() + "\n")
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        从最佳检查点加载可训练参数。该方法只在主进程中执行。
        记录加载最佳模型的日志信息。
        通过 unwrap_model 获取基础模型。如果模型有 pretrained_model 属性,则获取它作为 backbone_model。
        根据微调类型(lora 或其他)加载模型:
        如果是 LoRA 微调,加载适配器并设置 v_head 参数。
        如果是全参数或冻结微调,调用 load_trainable_params 函数加载可训练参数。
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

        model = unwrap_model(self.model)
        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })
        else: # freeze/full-tuning
            load_trainable_params(backbone_model, self.state.best_model_checkpoint)

其中有关LoRA相关的已经在之前的博客中提及。

其中不继承的函数 _load_best_model _save已经在上述代码中阐述。相关继承的方向,Seq2SeqTrainer 是 Hugging Face transformers 库中用于训练序列到序列(Seq2Seq)模型的类。它提供了一些方法和属性,方便进行模型训练和评估。以下是 Seq2SeqTrainer 的一些关键功能:

  • 训练和评估:提供训练和评估的核心方法,如 trainevaluate
  • 保存和加载模型:提供保存和加载模型检查点的方法,如 save_modelload_model
  • 分布式训练:支持多 GPU 和分布式训练。
  • 混合精度训练:支持使用 fp16bf16 进行混合精度训练。
  • 日志记录和监控:集成了日志记录和监控功能,支持 TensorBoard 等工具。

PeftTrainer 通过继承 Seq2SeqTrainer,扩展了其功能,支持参数高效的检查点保存和加载。它在初始化时进行一些必要的设置,并重写了 _save_load_best_model 方法,以支持 LoRA 等微调方法。通过这些扩展,PeftTrainer 可以更高效地管理模型的训练和微调过程。

实际调用中,这个训练器的核心功能是实现参数高效的预训练(pre-training),通过指定 stage=“pt” 来标明这是预训练阶段。代码涵盖了从数据准备、模型加载、数据预处理、训练、评估到记录日志的完整流程,特别适用于需要大量计算资源的语言建模任务。通过参数高效的方法,这个训练器能够在保证性能的前提下,显著减少训练所需的资源和时间。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值