模型的训练

模型的训练

自定义旋转嵌入模块和微调策略来优化模型性能

自定义旋转嵌入模块

在深度学习中,位置编码是一种关键技术,用于模型理解输入序列中各个位置的相对关系。

 old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
 def adaptive_ntk_init(self, dim, max_position_embeddings=4096, base=10000, device=None):
     self.dim = dim
     self.base = base
     old_init(self, dim, max_position_embeddings, base, device)
 ​
 def adaptive_ntk_forward(self, x, seq_len=None):
     if seq_len > self.max_seq_len_cached:
         t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
         inv_freq = self.inv_freq
         dim = self.dim
         alpha = seq_len / 1024 - 1
         base = self.base * alpha ** (dim / (dim-2))
         # print(seq_len,alpha,base)
         # exit()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim ))
 ​
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
         cos_cached = emb.cos()[None, None, :, :]
         sin_cached = emb.sin()[None, None, :, :]
         return (
             cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
         )
     return (
         self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
     )
 transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
 transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init

通过替换 LlamaRotaryEmbedding类的初始化和前向传播方法,实现了对其行为的自定义修改。具体修改包括在初始化方法中保存额外的参数,并在前向传播方法中根据输入的 seq_len 动态计算余弦和正弦矩阵,从而实现了更灵活和适应性更强的功能。

adaptive_ntk_init 方法的作用是在初始化 LlamaRotaryEmbedding 类的实例时,设置其维度 dim 和基础值 base,并且通过调用 old_init 方法完成了类的原始初始化过程,确保了其他默认行为和设置能够正常工作。

adaptive_ntk_forward 方法的作用是根据输入的序列长度 seq_len 动态地生成适应性的余弦和正弦频率张量,用于模型的前向传播过程。允许根据序列长度动态地调整频率信息,从而提高模型在不同序列长度下的表现。

微调

根据不同的微调策略和配置进行操作

 def _init_adapter(
         model: PreTrainedModel,
         model_args: ModelArguments,
         finetuning_args: FinetuningArguments,
         is_trainable: bool,
         is_mergeable: bool
 ) -> PreTrainedModel:
     r"""
     Initializes the adapters.
 ​
     Support full-parameter, freeze and LoRA training.
 ​
     Note that the trainable parameters must be cast to float32.
     """
 ​
  • model: 需要应用适配器的 PreTrainedModel 实例。

  • model_args: 与模型相关的配置参数。

  • finetuning_args: 微调过程中的参数和策略。

  • is_trainable: 指示模型是否可以训练的布尔值。

  • is_mergeable: 指示是否可以合并微调过程中的检查点的布尔值。

     if finetuning_args.finetuning_type == "none" and is_trainable:
         raise ValueError("You cannot use finetuning_type=none while training.")
 ​
     if finetuning_args.finetuning_type == "full":
         logger.info("Fine-tuning method: Full")
         model = model.float()
 ​
     if finetuning_args.finetuning_type == "freeze":
         logger.info("Fine-tuning method: Freeze")
 ​
         for name, param in model.named_parameters():
             if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                 param.requires_grad_(False)
             else:
                 param.data = param.data.to(torch.float32)
 ​
         if model_args.checkpoint_dir is not None:
             assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
 ​
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
         lastest_checkpoint = None
 ​
         if model_args.checkpoint_dir is not None:
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                 "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 ​
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
 ​
             for checkpoint in checkpoints_to_merge:
                 model = PeftModel.from_pretrained(model, checkpoint)
                 model = model.merge_and_unload()
 ​
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 ​
             if lastest_checkpoint is not None: # resume lora training or quantized inference
                 model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
 ​
         if is_trainable and lastest_checkpoint is None: # create new lora weights while training
             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
                 inference_mode=False,
                 r=finetuning_args.lora_rank,
                 lora_alpha=finetuning_args.lora_alpha,
                 lora_dropout=finetuning_args.lora_dropout,
                 target_modules=finetuning_args.lora_target
             )
             model = get_peft_model(model, lora_config)
 ​
     if model_args.checkpoint_dir is not None:
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 ​
     return model
  1. 验证微调策略:

    如果 finetuning_args.finetuning_type == "none"is_trainableTrue,则会引发 ValueError

  2. 完整微调:

    如果 finetuning_args.finetuning_type == "full",则将模型转换为 float 类型,表示可以对所有参数进行微调。

  3. 冻结部分层:

    如果 finetuning_args.finetuning_type == "freeze",则根据 finetuning_args.trainable_layers 中指定的层名,将不在此列表中的参数的 requires_grad 设为 False,同时将这些参数的数据类型转换为 torch.float32

  4. LoRA 训练:

    • 如果 finetuning_args.finetuning_type == "lora",则执行 LoRA(Layer-wise Relevance Adaptive Training)的训练策略。

    • 首先检查是否有预训练的 LoRA 检查点,根据需要加载并合并这些检查点。如果需要继续训练或者不合并的情况下,选择最新的检查点。

    • 如果需要在训练过程中创建新的 LoRA 权重,根据给定的参数配置创建 LoraConfig 并使用 get_peft_model 函数获取适当的 PeftModel

  5. 加载模型检查点: 如果 model_args.checkpoint_dir 不为 None,则加载微调后的模型检查点,并记录日志显示加载了哪些检查点。

最终返回经过适配器初始化后的 PreTrainedModel 实例。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值