项目实训5-训练相关工具实现(后补)

项目实训-训练相关工具实现

adaptive_ntk_init 适应不同设备或不同配置的网络

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def adaptive_ntk_init(self, dim, max_position_embeddings=4096, base=10000, device=None):
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

通过一个示例展示了如何在类方法中覆盖或扩展现有的初始化方法。这里具体使用了 transformers 库中的 LlamaRotaryEmbedding 类的初始化方法。下面是详细的解释:

  1. 引用原始的初始化方法:
    • old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
    • 这行代码创建了一个名为 old_init 的变量,指向 LlamaRotaryEmbedding 类的构造方法(__init__)。这允许在新的初始化函数中重用原始的初始化逻辑。
  2. 定义新的初始化方法:
    • def adaptive_ntk_init(self, dim, max_position_embeddings=4096, base=10000, device=None):
    • 这是一个类的方法,它在 self(类的实例)上调用,并接受几个参数,包括 dimmax_position_embeddingsbase,和 device。这些参数对应于 LlamaRotaryEmbedding 类原始构造函数的参数。
  3. 设置类的属性:
    • self.dim = dim
    • self.base = base
    • 这两行代码将方法的参数 dimbase 分别保存到类实例的属性中。这允许在类的其他方法中访问这些值。
  4. 调用原始的初始化方法:
    • old_init(self, dim, max_position_embeddings, base, device)
    • 这行代码调用先前引用的 old_init,即 LlamaRotaryEmbedding 的构造方法。通过这种方式,它确保类的所有基本初始化步骤都被执行,同时也加入了新的初始化逻辑。

此代码的目的是在不改变原始 LlamaRotaryEmbedding 类的情况下,添加或修改初始化过程中的某些行为。adaptive_ntk_init 可能是用于特定场景,需要调整 dimbase 参数并保留原始类初始化的行为。通过这种方式,可以灵活地在保留原始功能的同时扩展或自定义类的行为,这在复杂的软件系统中是一种常见的做法,尤其是在需要微调第三方库的行为时。

adaptive_ntk_forward 计算特定类型的位置编码

定义了一个名为 adaptive_ntk_forward 的方法,用于计算特定类型的位置编码,这种位置编码可能在序列模型(如Transformer模型)中使用。这里的位置编码是通过正余弦函数的变种实现的,其中包含了可调整的频率参数,使得编码可以自适应序列的长度。

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        inv_freq = self.inv_freq
        dim = self.dim
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim-2))
        # print(seq_len,alpha,base)
        # exit()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        )
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
    )
  1. 参数说明:
    • self:类的实例。
    • x:输入数据,可能代表一个批次的序列。
    • seq_len:序列的长度。
  2. 条件检查:
    • if seq_len > self.max_seq_len_cached
    • 这个条件检查当前序列长度是否超过了缓存的最大长度。如果是,需要重新计算位置编码。
  3. 位置编码的计算:
    • 使用torch.arange(seq_len, ...)生成一个从0到seq_len-1的序列,用于生成时间步向量t
    • inv_freq:初始化为类属性,定义了频率的逆。
    • dim:嵌入维度,取自类属性。
    • alphabase的重新计算:基于序列长度调整base,以改变频率的计算。这里alpha是根据序列长度与一个基准长度(例如1024)的比例减1计算的,然后base通过一个幂函数调整,依赖于dim的值。
    • 更新inv_freq:使用调整后的base重新计算逆频率。
    • freqs:通过外积计算每个时间步的每个维度的频率。
    • emb:将正弦和余弦编码合并成一个嵌入矩阵。
    • cos_cachedsin_cached:分别计算余弦和正弦值,并缓存这些值。
  4. 返回位置编码:
    • 根据输入数据的类型,从缓存的cossin值中截取相应长度的数据,并返回。

这段代码与上文的adaptive_ntk_init方法有直接联系。adaptive_ntk_init方法负责初始化和调整类属性,如dimbase等,这些属性直接影响adaptive_ntk_forward方法中位置编码的生成。通过adaptive_ntk_init方法中的参数调整,可以使adaptive_ntk_forward方法在不同的序列长度和设备条件下更加高效和适应性强。这种设计显示了一种在复杂模型(如基于transformer的模型)中动态调整和缓存重要计算结果的方式,旨在提高模型处理不同输入大小时的性能和灵活性。

_init_adapter 部分

def _init_adapter(
        model: PreTrainedModel,
        model_args: ModelArguments,
        finetuning_args: FinetuningArguments,
        is_trainable: bool,
        is_mergeable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Support full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")

        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        lastest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if lastest_checkpoint is not None: # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)

        if is_trainable and lastest_checkpoint is None: # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model

它负责初始化和配置用于微调的适配器。这种方法广泛应用于自然语言处理(NLP)中的预训练模型,如BERT、GPT等,以实现针对特定任务的微调。函数采用不同的微调策略,包括全参数微调、冻结部分参数微调和LoRA(低秩适配)微调。下面是对代码的详细解释:

参数和初步校验

  • 参数:
    • model: 这是一个预训练模型的实例,通常是来自 transformers 库的模型,比如 BERT 或 GPT。
    • model_args: 包含了模型相关的参数,例如模型架构的配置。
    • finetuning_args: 包含微调相关的设置,如微调策略(全参数、冻结、LoRA)和特定参数设置。
    • is_trainable: 表示模型的参数是否可训练。
    • is_mergeable: 表示是否可以合并多个预训练模型的检查点。
  • 校验:
    • 如果微调策略设置为 "none"is_trainable 为真,这是矛盾的,因为表示既不微调也要训练模型,所以会抛出一个值错误。

微调策略

  • 全参数微调 (full):
    • 在这种策略下,模型的所有参数都将参与训练,且模型的数据类型统一转换为 float,以保证数值计算的准确性。
    • 这种策略适用于任务与原始训练数据差异不大,或当有足够资源进行全面训练时。
  • 冻结参数微调 (freeze):
    • 在这种策略下,只有列表 finetuning_args.trainable_layers 中指定的层的参数会被设置为可训练。其他所有层的参数则被冻结,即不会在训练过程中更新。
    • 对参数进行冻结可以减少需要训练的参数数量,从而节省内存和计算资源,特别适用于当任务仅需要模型对特定部分进行微调时。
  • LoRA微调 (lora):
    • LoRA策略通过低秩适配的方式来微调模型。在这种策略中,不是直接修改原始参数,而是添加少量可训练的参数来调整原有参数的行为。
    • 检查点的合并和加载处理逻辑复杂,涉及到是否继续训练 LoRA 权重或合并多个模型检查点等多种情况的处理。
    • LoRA 微调允许模型在保持大部分预训练知识的基础上,通过调整极少数的额外参数来适应新任务。

加载检查点和返回模型

  • 如果 model_args 中指定了检查点目录,函数会在完成所有配置后,根据提供的检查点路径加载微调后的模型。这是确保模型配置正确并能够基于最新的训练状态启动训练或推理的关键步骤。
  • 函数最终返回配置好的模型实例,以供进一步训练或用于推理。

这个 _init_adapter 函数非常灵活,可以根据不同的需求和资源限制,通过多种微调策略来优化模型的性能和效率。这对于应对各种机器学习和NLP任务中遇到的不同挑战非常关键,特别是在处理大型语言模型时。

load_pretrained加载和配置预训练的模型

def load_pretrained(
        model_args: ModelArguments,
        finetuning_args: FinetuningArguments,
        is_trainable: Optional[bool] = False,
        stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
        tokenizer.pad_token_id = 0 # set as the <unk> token

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    # Register auto class to save the custom code files.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
        model.__class__.register_for_auto_class()

    # Initialize adapters
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

    print_trainable_params(model)

    return model, tokenizer

函数参数和预处理逻辑

  1. 函数参数:
    • model_args: 包含了模型的配置信息,例如模型名、版本、缓存目录等。
    • finetuning_args: 包含微调相关的配置,例如微调类型、相关层的配置等。
    • is_trainable: 布尔值,指示模型的参数是否应该设置为可训练。
    • stage: 表示特定训练阶段的标识符,如预训练("pt"),微调("sft"),或特定方法如强化学习("rm")和代理策略优化("ppo")。
  2. 预处理逻辑:
    • 初始校验确定在模型不可训练且未指定检查点的情况下,调整微调参数,避免加载非必要的预训练状态。

配置和分词器加载

  1. 配置构建:
    • 构建一个配置字典来指定从预训练库加载模型时使用的设置,例如是否信任远程代码,是否使用授权令牌等。
  2. 分词器加载:
    • 使用 AutoTokenizer.from_pretrained 方法加载预定义的分词器,这是处理文本输入的第一步。此处设置了如是否使用快速分词器,填充位置等。

模型配置和量化处理

  1. 量化配置:
    • 根据 model_args 中的量化位设置(如8位或4位),使用 bitsandbytes 库进行模型的量化,这可以大幅减少模型大小和提高推理速度。
  2. 模型加载:
    • 利用 AutoModelForCausalLM.from_pretrained 方法加载预训练模型,同时应用任何特定的量化和其他配置设置。此方法能够根据不同的配置加载不同架构的模型。

微调适配器和训练准备

  1. 适配器初始化:
    • _init_adapter 函数负责根据提供的参数配置模型,这包括设置哪些层是可训练的,以及如何处理模型的合并和分割。
  2. 训练准备:
    • 根据 is_trainable 确定是否需要进一步准备模型以进行训练,如添加特定的头部或应用特定的训练策略。

特殊阶段的处理

  1. 值头添加和配置:
    • 对于 "rm""ppo" 阶段,需要添加价值头(value head)。这通常用于基于模型预测的奖励机制,是强化学习和代理策略优化的关键组件。
  2. 奖励模型加载:
    • 特别是在 "ppo" 阶段,加载奖励模型对于优化代理策略至关重要。这涉及加载特定的适配器和配置,以确保模型能够根据奖励反馈进行优化。

模型最终设置

  1. 固定模型参数:
    • 如果不进行训练,所有模型参数设置为不可训练,并根据需要将模型数据类型转换为半精度(FP16)或其他量化形式,以优化性能。
  2. 调试支持:
    • 函数最后打印出模型的可训练参数,这有助于调试和确认模型配置正确。

返回值

  • 返回配置完成的模型和分词器实例,这对开始训练或进行模型推理是必需的。

这个 load_pretrained 函数非常复杂且功能丰富,它集成了多种不同的模型加载和配置技术。通过参数化和条件逻辑,该函数能够灵活地适应从基本微调到高级策略优化等各种不同的训练和推理需求。这种灵活性和强大的功能使得它可以高度定制化地适应不同的模型架构和任务需求,是现代深度学习框架中一个非常典型的实用功能实现。

prepare_args 解析参数

def prepare_args(
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    data_args.init_for_training()

    assert stage == "sft" or (not training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not (training_args.do_train and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True while training."

    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
        "Quantization is only compatible with the LoRA method."

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
        else:
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
                "Quantized model only accepts a single checkpoint."

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enable fp16 mixed precision training.")

    if data_args.prompt_template == "default":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args

定义了一个名为 prepare_args 的函数,它负责从命令行参数或 JSON 文件中解析各种训练、模型和数据相关的参数。这些参数被组织成不同的数据类,用于配置和控制模型的训练过程。函数还进行了一系列检查和设置,以确保模型训练和评估的正确性和最优化。以下是对代码中各部分的详细解释:

参数和解析逻辑

  1. 函数参数:
    • stage: 指定训练的阶段,可选值为 "pt", "sft", "rm", "ppo",分别表示不同的训练或应用阶段。
  2. 参数解析:
    • 使用 HfArgumentParser 解析器,这是由 Hugging Face 提供的一个工具,专门用于处理模型和训练相关的参数。
    • 如果命令行提供了一个以 .json 结尾的文件路径,参数将从该 JSON 文件中解析。这使得可以通过文件提前设定好所有的训练和模型参数。
    • 否则,参数将直接从命令行解析。

日志设置和参数校验

  1. 日志设置:
    • 根据 should_log 设置日志的详细级别,确保在训练和评估过程中有足够的信息输出。
    • 启用默认的日志处理器,并设定日志格式。
  2. 参数校验:
    • 根据不同的训练阶段,确保参数设置的合理性。例如,在某些阶段,predict_with_generate 不能被设置为 True
    • 对特定的参数组合进行断言,如在训练时不能启用 predict_with_generate
    • 对于使用量化的设置,检查是否与 LoRA 方法兼容。

特殊条件和警告

  1. 检查点和量化:
    • 如果指定了检查点目录,根据微调类型和量化设置进行进一步的校验。
    • 输出相关警告,如在非训练模式下使用量化可能导致性能降低。
  2. 混合精度和分布式训练:
    • 如果启用了训练但未使用 FP16 混合精度,发出警告建议启用以优化性能。
    • 在分布式训练设置中,检查并设置 ddp_find_unused_parameters 的值。

计算类型和优化器设置

  1. 计算类型:
    • 根据是否启用了 FP16 或 BF16 混合精度,设置模型的计算数据类型。
  2. 优化器设置:
    • 替换优化器设置以避免潜在的警告或错误。

日志记录和返回值

  1. 日志记录:
    • 在每个进程上输出关键的训练参数和系统配置的摘要,帮助监控和调试。
    • 记录完整的训练和评估参数信息。
  2. 返回值:
    • 返回解析和设置后的 ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, 和 FinetuningArguments 数据类实例。
    • 这些数据类包含了启动和控制模型训练所需的所有参数,确保模型可以按预期配置和执行。

通过这种方式,prepare_args 函数为模型训练和评估的启动提供了全面的参数配置和检查,确保了训练的顺利进行和最优化配置。

prepare_data 加载、验证和整合数据集

用于从不同的来源加载、验证和整合数据集,以便后续在机器学习模型中使用。函数详细处理了数据集的读取、格式识别、可选的完整性校验(通过 SHA-1 哈希)、数据选择和列重命名。以下是对代码中各部分的详细解释:

def prepare_data(
        model_args: ModelArguments,
        data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets

参数

  • model_args: 包含模型相关的参数,如缓存目录和认证令牌。
  • data_args: 包含数据处理和加载相关的参数,如数据集目录、数据集列表和最大样本数。

数据加载和校验

  1. 文件类型识别:
    • ext2type 字典用于根据文件扩展名确定数据的加载类型,支持 CSV、JSON、JSONL 和文本文件。
  2. 数据完整性校验:
    • checksum 函数通过计算文件的 SHA-1 哈希值并与提供的哈希值比较,来验证文件的完整性。如果不匹配,会记录一个警告。

数据集加载逻辑

  1. 从不同源加载:
    • 根据 data_args.dataset_list 中的指定,支持从 Hugging Face Hub、本地脚本或文件加载数据集。
    • 对于文件加载,支持从目录或单个文件加载,并且能够处理和校验多种文件类型。
  2. 数据集的后处理:
    • 使用 load_dataset 函数根据提供的路径或文件名加载数据。
    • 选择数据集的特定部分(如训练集、验证集)进行进一步处理。

数据筛选和格式调整

  1. 样本数限制:
    • 如果设置了 max_samples,则从数据集中选择前 max_samples 个样本。
  2. 列名标准化和添加:
    • 根据 dataset_attr 中的配置,重命名或添加列,确保每个数据集都有统一的列名(如 prompt、query、response、history)。
    • 添加 “prefix” 列,该列包含每个样本的前缀数据。

数据集整合

  1. 多数据集处理

    :

    • 如果只有一个数据集,直接返回该数据集。
    • 如果有多个数据集,使用 concatenate_datasets 将它们合并为一个大的数据集。

返回值

  • 函数返回处理和整合后的数据集(可能是单个数据集或合并后的数据集),这为模型训练或评估提供了准备就绪的数据。

总结

这个函数是数据预处理流程的关键部分,它不仅处理数据的加载和验证,还进行格式标准化和整合,确保不同来源和格式的数据能被模型有效处理。通过细致的错误检查和灵活的数据处理策略,该函数支持复杂的数据操作,使得最终的数据集能够满足具体的训练或评估需求。

preprocess_data 根据训练策略来整理数据

def preprocess_data(
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        data_args: DataTrainingArguments,
        training_args: Seq2SeqTrainingArguments,
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1: # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset

功能概览

preprocess_data 函数设计来根据不同的训练阶段(预训练、有监督训练、排名模型训练和策略优化)调整数据集的格式,使其适配特定的机器学习模型需求。它利用高度可配置的参数来灵活处理输入数据,并通过一系列定义良好的内部函数来实现这一目标。

参数详解

  • dataset: 该参数是一个包含原始数据的 Dataset 对象,通常包括多列,如文本、标签等。
  • tokenizer: 用于将文本字符串转换为模型可以处理的数值 token 序列。
  • data_args: 包含数据预处理所需的各种参数,如模板字符串、最大样本数量等。
  • training_args: 包含训练过程中配置的参数,例如批处理大小、是否进行预测生成等。
  • stage: 指明处理数据的具体阶段,每个阶段对数据的处理有不同的要求和目标。

数据处理流程

  1. 模板和提示生成:
    • 使用 data_args.prompt_template 初始化 Template 对象,这个模板用于生成特定格式的文本,适用于不同的训练阶段需求。
    • get_dialog 函数根据数据集中的列(如 “prompt”, “query”, “response”, “history”)组合生成完整的对话或文本序列。这对于构建复杂的输入模式如对话历史非常有用。
  2. 针对不同训练阶段的数据预处理策略:
    • 预训练 (preprocess_pretrain_dataset): 主要处理长文本,通过连接并分块来适应模型的输入尺寸,通常不包含特定任务的标签。
    • 有监督学习 (preprocess_supervised_dataset): 格式化数据以形成明确的输入-输出对,其中输入通常包含历史信息和当前的提示,输出是模型应该生成或预测的响应。
    • 无监督学习 (preprocess_unsupervised_dataset): 这种方式可能与有监督学习类似,但处理的自由度更高,标签可能直接是输入数据的变体。
    • 成对比较 (preprocess_pairwise_dataset): 特别用于需要模型评估两种或多种响应的场景,如排名模型训练或选择最佳回答。
  3. 批处理和多进程处理:
    • 使用 dataset.map 方法应用所选的预处理函数。这个步骤允许在多个处理器上并行处理数据,极大地提高了数据处理效率。
    • batched=True 参数确保数据以批量方式处理,而 num_proc 参数允许指定多个处理器。
  4. 调试和验证:
    • 根据不同阶段提供的打印函数(如 print_supervised_dataset_example),可以打印出处理后的样本,帮助开发者验证数据格式和内容的正确性。

总结

preprocess_data 函数是一个高度复杂且功能丰富的数据预处理工具,它通过灵活的参数和详尽的内部逻辑来适应多种数据处理需求。这使得它在准备数据以适应不同机器学习训练阶段时变得非常有效和可靠。通过细致的数据处理和优化,该函数帮助确保了数据集能够最大程度地支持模型训练的效果和效率,是构建高效机器学习工作流程中不可或缺的一部分。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值