huggingface的self._maybe_log_save_evaluate、self.save_model、self._save源码解读(权重等内容保存)


前言

在 Hugging Face 中,self._maybe_log_save_evaluate是有关权重等内容相关保存函数。本文通过该函数探索huggingface内部源码对权重相关文件保存方法,以供读者了解huggingface保存权重文件原理。


一、self.state与self.control初始化

请参考huggingface专栏中的huggingface的self.state与self.control来源(TrainerState与TrainerControl)文章点击这里

二、self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)源码解读

这个函数我们也在博客有解读,也可以作为补充参考。我这里按照源码流程逐渐解读器内容。

1、_maybe_log_save_evaluate完整源码

 def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
      if self.control.should_log:
          if is_torch_tpu_available():
              xm.mark_step()

          logs: Dict[str, float] = {}

          # all_gather + mean() to get average loss over all processes
          tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

          # reset tr_loss to zero
          tr_loss -= tr_loss

          logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
          logs["learning_rate"] = self._get_learning_rate()

          self._total_loss_scalar += tr_loss_scalar
          self._globalstep_last_logged = self.state.global_step
          self.store_flos()

          self.log(logs)

      metrics = None
      if self.control.should_evaluate:
          if isinstance(self.eval_dataset, dict):
              metrics = {}
              for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                  dataset_metrics = self.evaluate(
                      eval_dataset=eval_dataset,
                      ignore_keys=ignore_keys_for_eval,
                      metric_key_prefix=f"eval_{eval_dataset_name}",
                  )
                  metrics.update(dataset_metrics)
          else:
              metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
          self._report_to_hp_search(trial, self.state.global_step, metrics)

          # Run delayed LR scheduler now that metrics are populated
          if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
              metric_to_check = self.args.metric_for_best_model
              if not metric_to_check.startswith("eval_"):
                  metric_to_check = f"eval_{metric_to_check}"
              self.lr_scheduler.step(metrics[metric_to_check])

      if self.control.should_save:
          self._save_checkpoint(model, trial, metrics=metrics)
          self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2、control.should_log与is_torch_tpu_available

self.control.should_log来源self.control,is_torch_tpu_available是判断torch_xla安装,若安装是TPU环境条件使用,其调用源码如下:

if self.control.should_log:
   if is_torch_tpu_available():
      xm.mark_step()

而is_torch_tpu_available源码如下:

@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    if not _torch_available:
        return False
    if importlib.util.find_spec("torch_xla") is not None:
        if check_device:
            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
            try:
                import torch_xla.core.xla_model as xm

                _ = xm.xla_device()
                return True
            except RuntimeError:
                return False
        return True
    return False

3、评估(self.control.should_evaluate)

需要评估方法,得到评估值metrics,这个一般忽略!


metrics = None
if self.control.should_evaluate:
    if isinstance(self.eval_dataset, dict):
        metrics = {}
        for eval_dataset_name, eval_dataset in self.eval_dataset.items():
            dataset_metrics = self.evaluate(
                eval_dataset=eval_dataset,
                ignore_keys=ignore_keys_for_eval,
                metric_key_prefix=f"eval_{eval_dataset_name}",
            )
            metrics.update(dataset_metrics)
    else:
        metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
    self._report_to_hp_search(trial, self.state.global_step, metrics)

    # Run delayed LR scheduler now that metrics are populated
    if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        self.lr_scheduler.step(metrics[metric_to_check])

4、保存权重

这里,self.control.should_save由self.control给出,这里与权重保存相关是self._save_checkpoint函数调用。而self.control = self.callback_handler.on_save(self.args, self.state, self.control)是对self.control的更新,

if self.control.should_save:
   self._save_checkpoint(model, trial, metrics=metrics)
   self.control = self.callback_handler.on_save(self.args, self.state, self.control)

5、self.callback_handler类

这里,我给出self.callback_handler相关内容。因为,huggingface中有很多处self.control = self.callback_handler.*等的方法,其大致逻辑就是下面的类。我们发现self.callback_handler.on_train_begin、on_epoch_end、on_save等都是调用self.call_event("名称", args, state, control,其它参数)方法,实际就是参数更新方式,其所有源码如下:

class CallbackHandler(TrainerCallback):
    """Internal class that just calls the list of callbacks in order."""

    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = None
        self.eval_dataloader = None

        if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
            logger.warning(
                "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
                + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
                + "callbacks is\n:"
                + self.callback_list
            )

    def add_callback(self, callback):
        cb = callback() if isinstance(callback, type) else callback
        cb_class = callback if isinstance(callback, type) else callback.__class__
        if cb_class in [c.__class__ for c in self.callbacks]:
            logger.warning(
                f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
                + "list of callbacks is\n:"
                + self.callback_list
            )
        self.callbacks.append(cb)

    def pop_callback(self, callback):
        if isinstance(callback, type):
            for cb in self.callbacks:
                if isinstance(cb, callback):
                    self.callbacks.remove(cb)
                    return cb
        else:
            for cb in self.callbacks:
                if cb == callback:
                    self.callbacks.remove(cb)
                    return cb

    def remove_callback(self, callback):
        if isinstance(callback, type):
            for cb in self.callbacks:
                if isinstance(cb, callback):
                    self.callbacks.remove(cb)
                    return
        else:
            self.callbacks.remove(callback)

    @property
    def callback_list(self):
        return "\n".join(cb.__class__.__name__ for cb in self.callbacks)

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_init_end", args, state, control)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_training_stop = False
        return self.call_event("on_train_begin", args, state, control)

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_train_end", args, state, control)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_epoch_stop = False
        return self.call_event("on_epoch_begin", args, state, control)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_epoch_end", args, state, control)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_log = False
        control.should_evaluate = False
        control.should_save = False
        return self.call_event("on_step_begin", args, state, control)

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_substep_end", args, state, control)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_step_end", args, state, control)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
        control.should_evaluate = False
        return self.call_event("on_evaluate", args, state, control, metrics=metrics)

    def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
        return self.call_event("on_predict", args, state, control, metrics=metrics)

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        control.should_save = False
        return self.call_event("on_save", args, state, control)

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
        control.should_log = False
        return self.call_event("on_log", args, state, control, logs=logs)

    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
        return self.call_event("on_prediction_step", args, state, control)

    def call_event(self, event, args, state, control, **kwargs):
        for callback in self.callbacks:
            result = getattr(callback, event)(
                args,
                state,
                control,
                model=self.model,
                tokenizer=self.tokenizer,
                optimizer=self.optimizer,
                lr_scheduler=self.lr_scheduler,
                train_dataloader=self.train_dataloader,
                eval_dataloader=self.eval_dataloader,
                **kwargs,
            )
            # A Callback can skip the return of `control` if it doesn't change it.
            if result is not None:
                control = result
        return control

这里, def call_event(self, event, args, state, control, **kwargs):非常重要,是self.control参数返回内容。

三、self._save_checkpoint(model, trial, metrics=metrics)源码解读

在上面最重要的是

1、完整的源码

给出_save_checkpoint所有源码,如下:

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
		#在所有情况下,包括ddp/dp/deepspeed,self.model始终是我们想要保存的模型的引用,除了FullyShardedDDP。 assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" 的中文翻译是:"在所有情况下,包括ddp/dp/deepspeed,self.model始终是我们想要保存的模型的引用,除了FullyShardedDDP。"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if self.fsdp or self.is_fsdp_enabled:
            if self.is_fsdp_enabled:
                save_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                )
            else:
                # FSDP has a different interface for saving optimizer states.
                # Needs to be called on all ranks to gather all states.
                # full_optim_state_dict will be deprecated after Pytorch 2.2!
                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp and not self.is_fsdp_enabled:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) # TRAINER_STATE_NAME=trainer_state.json

        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2、获得保存路径与模型参数计算

output_dir为最终保存路径(如:‘./out_dirs/checkpoint-1’),self.store_flos()为模型参数计算。

 # Save model checkpoint
 checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

 if self.hp_search_backend is None and trial is None:
     self.store_flos()

 run_dir = self._get_output_dir(trial=trial)
 output_dir = os.path.join(run_dir, checkpoint_folder)

self.store_flos()参数计算

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0

3、 self.save_model(output_dir, _internal_call=True)源码解读

这段代码是一个用于保存模型的方法 save_model。在这个方法中,根据不同的条件和设置,选择性地保存模型的权重参数。方法首先检查是否在主进程中运行,然后根据不同的训练环境和设置选择合适的保存方式。其中涵盖了对不同分布式训练策略(如DDP、DeepSpeed等)的处理逻辑,以及对应用于不同训练环境的保存方式。最后,如果设置为将模型推送到Hub(Hub是Hugging Face提供的模型版本控制平台),则会在保存模型后将模型推送到Hub。

 self.save_model(output_dir, _internal_call=True)

4、self.sharded_ddp == ShardedDDPOption.SIMPLE条件保存optimizer and scheduler状态

这段代码是Hugging Face中用于保存优化器和学习率调度器状态的部分。在这里,如果使用的分布式训练策略是ShardedDDPOption.SIMPLE,则调用了optimizer对象的consolidate_state_dict()方法,这个方法的作用是将优化器的状态字典整合成一个单独的状态字典。

通过调用consolidate_state_dict()方法,可以将分布式训练中每个处理器上的优化器状态整合到一个共享的状态字典中,以便在需要时能够正确保存和恢复整个优化器的状态。这有助于在分布式设置中更好地管理和同步优化器状态,确保训练的正确性和一致性。

 # Save optimizer and scheduler
 if self.sharded_ddp == ShardedDDPOption.SIMPLE:
     self.optimizer.consolidate_state_dict()

5、self.fsdp or self.is_fsdp_enabled条件保存优化器状态

这段代码是Hugging Face中用于在使用FullyShardedDDP(FSDP)分布式训练策略时保存优化器状态的部分。在这里,首先检查是否启用了FSDP或者FSDP已经被设置,然后根据情况调用相应的保存优化器状态的函数。

如果self.is_fsdp_enabled为True,则调用save_fsdp_optimizer()函数来保存FSDP优化器的状态。这个函数会通过传递FSDP插件、加速器、优化器、模型和输出目录等参数来保存FSDP优化器的状态。

如果self.is_fsdp_enabled为False,则说明FSDP已经被设置但未启用,这时候会调用FSDP的另一种接口来保存优化器状态。这个接口需要在所有进程中调用以收集所有状态,并且提示在PyTorch 2.2版本后full_optim_state_dict方法将被弃用。

总的来说,这段代码负责根据FSDP的状态和设置来选择合适的方式保存优化器状态,确保在使用FSDP进行分布式训练时能够正确地保存和恢复优化器状态。

if self.fsdp or self.is_fsdp_enabled:
    if self.is_fsdp_enabled:
        save_fsdp_optimizer(
            self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
        )
    else:
        # FSDP has a different interface for saving optimizer states.
        # Needs to be called on all ranks to gather all states.
        # full_optim_state_dict will be deprecated after Pytorch 2.2!
        full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

6、不同条件保存optimizer 、 scheduler and scale等状态

这段代码是 Hugging Face 中用于保存优化器和学习率调度器状态的部分。根据不同的训练环境和设置,选择性地保存优化器和学习率调度器的状态。

如果正在使用 Torch TPU,代码会使用 xm.rendezvous(“saving_optimizer_states”) 来同步进程,然后保存优化器和学习率调度器的状态到指定的输出目录。
如果启用了 SageMaker Model Parallelism(SMP),代码会在本地保存优化器状态字典,然后通过 smp.barrier() 来同步进程,最终根据条件决定是否在第一个进程保存优化器状态到指定的输出目录。
如果不是在 DeepSpeed 环境下且需要保存模型,代码会根据条件选择是否保存完整的优化器状态字典或者只保存特定部分的状态到指定的输出目录。
总的来说,这段代码负责根据不同的训练环境和设置来保存优化器和学习率调度器的状态,确保在需要时能够正确地保存和恢复这些状态,以便在训练过程中进行必要的操作。

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp and not self.is_fsdp_enabled:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

OPTIMIZER_NAME=‘optimizer.pt’
SCHEDULER_NAME=‘scheduler.pt’
SCALER_NAME=‘scaler.pt’

7、Determine the new best metric / best model checkpoint

这段代码是 Hugging Face 中用于确定最佳指标和最佳模型检查点的部分。在这里,首先检查是否提供了指标(metrics)以及指定了用于选择最佳模型的指标(self.args.metric_for_best_model)。

如果满足条件,会根据指定的指标名称来获取对应的指标值。
然后根据设置中的 greater_is_better 参数来确定比较操作符是大于还是小于。
最后,如果当前的指标值比之前记录的最佳指标值要好(根据 greater_is_better 参数决定),则更新记录的最佳指标值和最佳模型检查点路径为当前的指标值和输出目录。
这段代码的作用是在训练过程中根据指定的指标来判断当前模型是否为最佳模型,并记录下最佳指标值和对应的模型检查点路径,以便在训练结束后能够找到表现最好的模型。

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
    metric_to_check = self.args.metric_for_best_model
    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"
    metric_value = metrics[metric_to_check]

    operator = np.greater if self.args.greater_is_better else np.less
    if (
        self.state.best_metric is None
        or self.state.best_model_checkpoint is None
        or operator(metric_value, self.state.best_metric)
    ):
        self.state.best_metric = metric_value
        self.state.best_model_checkpoint = output_dir

8、self.state状态保存

只要满足self.args.should_save条件,调用self.state.save_to_json来保存state的状态,这个内容可参考博客


 # Save the Trainer state
 if self.args.should_save:
     self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) # TRAINER_STATE_NAME=trainer_state.json

9、随机状态数保存

这段代码是关于在非分布式训练中保存随机数生成器状态的。首先,它保存了不同库(Python、NumPy、PyTorch)的随机数生成器状态,以及在CUDA和TPU环境下的状态。在分布式训练中,如果检测到CUDA可用并且并行模式为非分布式,则会保存全局的CUDA随机数生成器状态;否则,保存当前CUDA随机数生成器状态。
接着,代码检查是否TPU可用,如果是的话,保存XLA(PyTorch XLA)的随机数生成器状态。
在保存随机数生成器状态之前,为了避免出现输出目录还不存在的情况,代码会先创建输出目录(output_dir)。
最后,根据进程数量(world_size)来决定是将随机数生成器状态保存为单个文件还是多个文件。如果进程数量小于等于1,则保存为单个文件"rng_state.pth";否则,根据进程索引(process_index)保存为多个文件,文件名格式为"rng_state_{self.args.process_index}.pth"。


# Save RNG state in non-distributed training
rng_states = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
        rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
    else:
        rng_states["cuda"] = torch.cuda.random.get_rng_state()

if is_torch_tpu_available():
    rng_states["xla"] = xm.get_rng_state()

# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)

if self.args.world_size <= 1:
    torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
    torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

注:名称为rng_state.pth或rng_state_{self.args.process_index}.pth

10、控制权重限制保存

这段代码是循环查看保存权重数量限制功能。
首先,它会检查是否设置了保存的总数限制(save_total_limit),如果未设置或者设置为小于等于0,则直接返回。

接着,代码会对已经存在的检查点按照一定规则排序,然后判断是否需要删除旧的检查点。如果排序后的检查点数量不超过保存总数限制,则直接返回。

然后,代码会判断一些特殊情况,比如当save_total_limit为1且同时设置了load_best_model_at_end为True时,为了避免删除最后一个检查点(用于允许恢复训练),会将save_total_limit设为2。

最后,根据计算出需要删除的检查点数量,选择要删除的检查点,并逐个删除。删除时会输出日志信息,指明正在删除的旧检查点。

 # Maybe delete some older checkpoints.
 if self.args.should_save:
     self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

而_rotate_checkpoints源码如下:

def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
     if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
         return

     # Check if we should delete older checkpoint(s)
     checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
     if len(checkpoints_sorted) <= self.args.save_total_limit:
         return

     # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
     # we don't do to allow resuming.
     save_total_limit = self.args.save_total_limit
     if (
         self.state.best_model_checkpoint is not None
         and self.args.save_total_limit == 1
         and checkpoints_sorted[-1] != self.state.best_model_checkpoint
     ):
         save_total_limit = 2

     number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
     checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
     for checkpoint in checkpoints_to_be_deleted:
         logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
         shutil.rmtree(checkpoint, ignore_errors=True)

四、self.save_model(output_dir, _internal_call=True)

1、save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False)源码

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
	if output_dir is None:
       output_dir = self.args.output_dir

    if is_torch_tpu_available():
        self._save_tpu(output_dir)
    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
      	...
    elif (ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or 
    ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None or self.is_fsdp_enabled
    ):
        ...

    elif self.is_deepspeed_enabled:
       ...
    elif self.args.should_save:
        self._save(output_dir)
    # Push to the Hub when `save_model` is called by the user.
    if self.args.push_to_hub and not _internal_call:
        self.push_to_hub(commit_message="Model save")

2、_save(self, output_dir: Optional[str] = None, state_dict=None)源码解读(来源_save_model)

此部分保存config.json文件、模型权重文件pytorch_model.bin与trainer的self.args内容training_args.bin。该文件调用来源def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):此函数代码内容。

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

当然,若self.tokenizer is not None也会保存tokenizer相关内容self.tokenizer.save_pretrained(output_dir)

若是使用self.model.save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors )保存方式,可保存模型训练权重与config.json文件。

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        push_to_hub: bool = False,
        max_shard_size: Union[int, str] = "10GB",
        safe_serialization: bool = False,
        variant: Optional[str] = None,
        **kwargs,
    ):

同时,我们知道TRAINING_ARGS_NAME=training_args.bin,就是保存的trainer中的self.args内容。

 #Good practice: save your training arguments together with the trained model
 torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

五、更改

当然我在博客写了一个huggingface的Demo。这里,我借助此Demo说明一下,自己更改某些参数内容。

trainer的args参数修改

比如我想对huggingface的Trainer中添加某参数,可使用trainer.args.a=20实现a变量添加。当然,你也可以通过这种方式来修改某些参数,已使实现你的需求。

# 定义Trainer对象
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

trainer.args.a=20
# 开始训练
trainer.train()
trainer.save_model("/huggingface_demo/out_dirs")

总结

之所以,单独写出来,是要强调,我们可以使用这种方式灵活实现自己需求。特别是使用save_model等函数来保存权重,也可对Trainer进行继承来保存你想要的参数等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

tangjunjun-owen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值