Model Training
This section covers a custom rotary embedding module and the fine-tuning strategies used to optimize model performance.
Custom Rotary Embedding Module
In deep learning, positional encoding is a key technique that lets a model understand the relative relationships between positions in an input sequence.
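As a point of reference for the patch below, here is a minimal sketch of how standard rotary position embeddings (RoPE) derive their frequencies from a fixed base (10000, matching the default used below); `dim = 128` is an assumed per-head dimension, not a value taken from the original text:

```python
import torch

base, dim = 10000, 128  # assumed per-head dimension; base matches the default below
# One inverse frequency per pair of dimensions, decaying geometrically with the base.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

positions = torch.arange(8).float()                     # first 8 positions
angles = torch.einsum("i,j->ij", positions, inv_freq)   # rotation angle per position/frequency
cos, sin = angles.cos(), angles.sin()                   # used to rotate query/key pairs
```

The adaptive patch below keeps this structure but rescales the base when the sequence grows beyond the cached length.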
```python
import torch
import transformers

# Keep a reference to the original constructor so it can still be called.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=4096, base=10000, device=None):
    # Store dim and base so the forward pass can recompute frequencies later.
    self.dim = dim
    self.base = base
    old_init(self, dim, max_position_embeddings, base, device)

def adaptive_ntk_forward(self, x, seq_len=None):
    if seq_len > self.max_seq_len_cached:
        # Sequence is longer than the cached tables: rescale the RoPE base
        # (NTK-aware scaling) and rebuild the cos/sin tables on the fly.
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        dim = self.dim
        alpha = seq_len / 1024 - 1
        base = self.base * alpha ** (dim / (dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]
        return (
            cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
    # Within the cached range: return the precomputed tables.
    return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
    )

# Monkey-patch the class so every LlamaRotaryEmbedding uses the adaptive logic.
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
```
By replacing the `__init__` and `forward` methods of the `LlamaRotaryEmbedding` class, the code above customizes its behavior: the new initializer stores the extra parameters (`dim` and `base`), and the new forward pass dynamically recomputes the cosine and sine matrices from the input `seq_len`, which makes the embedding more flexible and better adapted to varying sequence lengths.
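The patch is a sequence of module-level statements, so it only needs to run once before the model is built, so that newly created `LlamaRotaryEmbedding` instances pick up the patched `__init__`. A minimal usage sketch, assuming the statements above live in a module with the hypothetical name `adaptive_ntk_patch` and the weights at a placeholder path:

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

import adaptive_ntk_patch  # hypothetical module containing the patch statements above

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")  # placeholder path
model = LlamaForCausalLM.from_pretrained("path/to/llama")    # rotary layers now use the adaptive methods
```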
The `adaptive_ntk_init` method stores the instance's dimension `dim` and base value `base` when a `LlamaRotaryEmbedding` is created, then calls `old_init` to run the original initialization, so that all other default behavior and settings continue to work.
The `adaptive_ntk_forward` method dynamically generates adaptive cosine and sine frequency tensors from the input sequence length `seq_len` during the forward pass. Adjusting the frequencies to the sequence length improves the model's behavior across different, and in particular longer, sequence lengths.
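To see how the scaling behaves, the sketch below evaluates the base-rescaling formula from `adaptive_ntk_forward` for a few sequence lengths; `dim = 128` is an assumed per-head dimension, and the rescaling only triggers once `seq_len` exceeds the cached maximum:

```python
base, dim = 10000, 128  # dim is an assumed per-head dimension; base matches the patch default

for seq_len in (2048, 4096, 8192, 16384):
    alpha = seq_len / 1024 - 1
    scaled_base = base * alpha ** (dim / (dim - 2))
    print(f"seq_len={seq_len:6d}  alpha={alpha:5.1f}  base={scaled_base:,.0f}")
```

Longer sequences map to a larger RoPE base, which stretches the rotary frequencies so that positions beyond the original training length remain distinguishable.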
Fine-tuning
Adapter initialization proceeds according to the chosen fine-tuning strategy and configuration.
```python
def _init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool,
    is_mergeable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Support full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """
```
- `model`: the `PreTrainedModel` instance to which the adapter is applied.
- `model_args`: model-related configuration parameters.
- `finetuning_args`: parameters and strategy for the fine-tuning process.
- `is_trainable`: boolean indicating whether the model is trainable.
- `is_mergeable`: boolean indicating whether checkpoints produced during fine-tuning can be merged.
```python
    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), \
                "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        lastest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable):
                # continually train on the lora weights
                checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if lastest_checkpoint is not None:  # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)

        if is_trainable and lastest_checkpoint is None:  # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model
```
- **Validate the fine-tuning strategy:** if `finetuning_args.finetuning_type == "none"` and `is_trainable` is `True`, a `ValueError` is raised.
- **Full fine-tuning:** if `finetuning_args.finetuning_type == "full"`, the model is cast to `float` (float32), meaning all parameters can be fine-tuned.
- **Freezing layers:** if `finetuning_args.finetuning_type == "freeze"`, parameters whose names do not match any entry in `finetuning_args.trainable_layers` have `requires_grad` set to `False`, while the parameters of the trainable layers are cast to `torch.float32`.
- **LoRA training:**
  - If `finetuning_args.finetuning_type == "lora"`, the LoRA (Low-Rank Adaptation) training strategy is applied.
  - The function first checks whether pre-trained LoRA checkpoints exist, loading and merging them as needed. If training is being resumed, or the weights cannot be merged, the latest checkpoint is kept separate and reloaded as a trainable `PeftModel`.
  - If new LoRA weights need to be created for training, a `LoraConfig` is built from the given parameters and `get_peft_model` wraps the model in the appropriate `PeftModel`.
- **Loading model checkpoints:** if `model_args.checkpoint_dir` is not `None`, the fine-tuned checkpoints are loaded and a log message records which checkpoints were used.
Finally, the function returns the adapter-initialized `PreTrainedModel` instance.
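As a rough illustration of how `_init_adapter` might be invoked for LoRA training, the sketch below builds stand-in argument objects containing only the fields the function actually reads; the real `ModelArguments` / `FinetuningArguments` classes contain more fields, and the model path and LoRA hyperparameter values here are assumptions:

```python
from types import SimpleNamespace
from transformers import LlamaForCausalLM

# Stand-ins for ModelArguments / FinetuningArguments, limited to the fields
# that _init_adapter reads; values are illustrative assumptions.
model_args = SimpleNamespace(
    checkpoint_dir=None,          # no previous LoRA checkpoints to merge or resume
    resume_lora_training=False,
)
finetuning_args = SimpleNamespace(
    finetuning_type="lora",
    lora_rank=8,
    lora_alpha=32,
    lora_dropout=0.05,
    lora_target=["q_proj", "v_proj"],
)

model = LlamaForCausalLM.from_pretrained("path/to/llama")  # placeholder path
model = _init_adapter(model, model_args, finetuning_args, is_trainable=True, is_mergeable=True)
model.print_trainable_parameters()  # only the LoRA matrices should be trainable
```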