A Detailed Look at Part of the common Library: An Advanced Fine-Tuning Framework

Our common library implements a comprehensive deep-learning training and fine-tuning toolkit. It primarily integrates the seven packages described earlier and provides a variety of functions and configurations for model training and data processing. Below is a detailed description of the methods I was responsible for:

1. Custom adapter initialization (_init_adapter)

  • This function initializes the model's adapters and adjusts model parameters according to the chosen fine-tuning strategy.

  • It supports full-parameter fine-tuning (full), parameter freezing (freeze), and LoRA (low-rank adapters).

  • For LoRA fine-tuning, it supports loading the model from pretrained checkpoints, merging checkpoints into the base weights, and creating new LoRA weights.

  • Code:

    def _init_adapter(
            model: PreTrainedModel,
            model_args: ModelArguments,
            finetuning_args: FinetuningArguments,
            is_trainable: bool,
            is_mergeable: bool
    ) -> PreTrainedModel:
        r"""
        Initializes the adapters.
    
        Support full-parameter, freeze and LoRA training.
    
        Note that the trainable parameters must be cast to float32.
        """
    
        if finetuning_args.finetuning_type == "none" and is_trainable:
            raise ValueError("You cannot use finetuning_type=none while training.")
    
        if finetuning_args.finetuning_type == "full":
            logger.info("Fine-tuning method: Full")
            model = model.float()
    
        if finetuning_args.finetuning_type == "freeze":
            logger.info("Fine-tuning method: Freeze")
    
            for name, param in model.named_parameters():
                if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                    param.requires_grad_(False)
                else:
                    param.data = param.data.to(torch.float32)
    
            if model_args.checkpoint_dir is not None:
                assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
    
        if finetuning_args.finetuning_type == "lora":
            logger.info("Fine-tuning method: LoRA")
            latest_checkpoint = None

            if model_args.checkpoint_dir is not None:
                assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                    "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
                assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                    "The given checkpoint may not be a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

                if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                    checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
                else:
                    checkpoints_to_merge = model_args.checkpoint_dir

                for checkpoint in checkpoints_to_merge:
                    model = PeftModel.from_pretrained(model, checkpoint)
                    model = model.merge_and_unload()

                if len(checkpoints_to_merge) > 0:
                    logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

                if latest_checkpoint is not None: # resume lora training or quantized inference
                    model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)

            if is_trainable and latest_checkpoint is None: # create new lora weights while training
                lora_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=finetuning_args.lora_rank,
                    lora_alpha=finetuning_args.lora_alpha,
                    lora_dropout=finetuning_args.lora_dropout,
                    target_modules=finetuning_args.lora_target
                )
                model = get_peft_model(model, lora_config)
    
        if model_args.checkpoint_dir is not None:
            logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
    
        return model
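The subtle part of the LoRA branch above is deciding which checkpoints get merged into the base weights and which one (if any) is kept live for resumed training or quantized inference. That decision can be isolated as a standalone sketch (the helper name `split_lora_checkpoints` is illustrative, not part of the library):

```python
from typing import List, Optional, Tuple

def split_lora_checkpoints(
        checkpoint_dirs: List[str],
        is_trainable: bool,
        resume_lora_training: bool,
        is_mergeable: bool
) -> Tuple[List[str], Optional[str]]:
    """Mirrors the LoRA branch of _init_adapter: every checkpoint in the first
    returned list is merged into the base model via merge_and_unload(); the
    second element, if not None, is loaded as a live PeftModel instead."""
    latest_checkpoint = None
    if (is_trainable and resume_lora_training) or (not is_mergeable):
        # Keep the newest checkpoint as trainable LoRA weights; merge the rest.
        checkpoints_to_merge, latest_checkpoint = checkpoint_dirs[:-1], checkpoint_dirs[-1]
    else:
        # Mergeable inference: fold every checkpoint into the base weights.
        checkpoints_to_merge = checkpoint_dirs
    return checkpoints_to_merge, latest_checkpoint
```

For example, with `--checkpoint_dir ckpt1,ckpt2` and `resume_lora_training` enabled, `ckpt1` is merged and training resumes from `ckpt2`.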

2. Improved rotary position embeddings for the LLaMA model (adaptive_ntk_init / adaptive_ntk_forward)

  • By patching the __init__ and forward methods of transformers.models.llama.modeling_llama.LlamaRotaryEmbedding, this adds support for dynamically adapting the position embeddings to the sequence length.

  • The change lets the model handle sequences of varying lengths more effectively, especially when the sequence length exceeds the maximum length seen during pretraining.

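Since the patched __init__/forward bodies depend on transformers-version internals, here is a minimal, self-contained sketch of the NTK-aware scaling idea described in this section. The class name and the alpha formula follow the common adaptive-NTK recipe and are an approximation of the actual patch, not a verbatim copy:

```python
import torch

class AdaptiveNTKRotaryEmbedding(torch.nn.Module):
    """Sketch of NTK-aware rotary embeddings: when the requested sequence
    length exceeds the pretraining window, the RoPE base is enlarged so the
    lowest frequencies stretch to cover the longer context."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def forward(self, seq_len: int):
        base = self.base
        if seq_len > self.max_position_embeddings:
            # NTK-aware scaling: the dim/(dim-2) exponent compensates for the
            # fact that only the lowest frequency needs the full stretch.
            alpha = seq_len / self.max_position_embeddings
            base = base * alpha ** (self.dim / (self.dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(seq_len, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)       # (seq_len, dim)
        return emb.cos(), emb.sin()
```

Within the pretraining window this reduces to vanilla RoPE; only longer requests rescale the base, so short-sequence behavior is unchanged.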

3. Configuration classes and training arguments (using the methods implemented earlier)

  • HfArgumentParser parses command-line arguments to set up the model and training parameters.

  • Several configuration classes handle model arguments (ModelArguments), data and training arguments (DataTrainingArguments), fine-tuning arguments (FinetuningArguments), and generation arguments (GeneratingArguments).

    from .config import (
        ModelArguments,
        DataTrainingArguments,
        FinetuningArguments,
        GeneratingArguments
    )
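The parsing pattern can be illustrated with a toy stand-in dataclass (the real argument classes live in .config; the field names below are illustrative only):

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ToyFinetuningArguments:
    # Stand-in for the project's FinetuningArguments.
    finetuning_type: str = field(default="lora", metadata={"help": "full, freeze or lora"})
    lora_rank: int = field(default=8, metadata={"help": "rank of the LoRA matrices"})

parser = HfArgumentParser(ToyFinetuningArguments)
# The training scripts read sys.argv; here we pass the arguments explicitly.
(finetuning_args,) = parser.parse_args_into_dataclasses(args=["--lora_rank", "16"])
```

In the real scripts all four dataclasses are passed to one HfArgumentParser, which returns one populated instance per class.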

4. Logging and helper utilities (using the methods implemented earlier)

  • get_logger configures and returns a logger instance, making it easy to track and record events during training.

  • Other helpers such as load_trainable_params and prepare_model_for_training load trainable parameters and prepare the model for training.

    from .other import (
        get_logger,
        load_trainable_params,
        load_valuehead_params,
        print_trainable_params,
        prepare_model_for_training,
        IGNORE_INDEX
    )
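The project's get_logger is presumably a thin wrapper around the standard logging module; a plausible reconstruction is sketched below (the format string and log level are assumptions, not the library's exact values):

```python
import logging
import sys

def get_logger(name: str) -> logging.Logger:
    # Hypothetical reconstruction: a module-level logger writing timestamped
    # INFO-and-above messages to stdout.
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    return logger
```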

This part of the common library provides a framework for advanced fine-tuning of causal language models. It lets other team members optimize model performance and adaptability through different fine-tuning strategies such as full-parameter fine-tuning, parameter freezing, and LoRA. In addition, through targeted modifications to the model architecture, it supports more flexible and dynamic training needs, particularly handling inputs of varying lengths. This flexibility and extensibility make it well suited for fine-grained tuning and optimization of NLP models in both experiments and production.
