OLMo Series: Pretraining, Part 1 (olmo/train.py)

CLASS 1

@dataclass
class SpeedMonitor:
    cfg: SpeedMonitorConfig
    start_times: Deque[float] = field(default_factory=lambda: deque([]))
    global_total_tokens: int = 0
    device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([]))

    def batch_start(self, global_total_tokens: int, device_batch_num_tokens: int, record: bool = True) -> None:
        self.global_total_tokens = global_total_tokens
        if record:
            if len(self.start_times) >= self.cfg.window_size:
                self.start_times.popleft()
                self.device_interval_tokens.popleft()
            self.start_times.append(time.monotonic())
            self.device_interval_tokens.append(device_batch_num_tokens)

    def reset(self) -> None:
        self.start_times.clear()
        self.device_interval_tokens.clear()

    def check(self) -> Dict[str, float]:
        metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
        if self.start_times:
            interval_seconds = time.monotonic() - self.start_times[0]
            interval_batches = len(self.start_times)
            interval_tokens = sum(self.device_interval_tokens)
            metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
            metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
        return metrics
  1. @dataclass: a decorator that turns the class into a dataclass, automatically generating __init__, __repr__, and other methods, which keeps the class definition concise.
  2. class SpeedMonitor: defines a class named SpeedMonitor, used to monitor training throughput.
  3. cfg: SpeedMonitorConfig: the monitor's configuration, of type SpeedMonitorConfig.
  4. start_times: Deque[float] = field(default_factory=lambda: deque([])): a double-ended queue storing the timestamp at which each recorded batch started; initially empty.
  5. global_total_tokens: int = 0: the global total number of tokens seen so far; starts at 0.
  6. device_interval_tokens: Deque[int] = field(default_factory=lambda: deque([])): a double-ended queue storing the number of tokens per batch on this device; initially empty.
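A tiny illustration (a toy class, not from OLMo) of what @dataclass together with field(default_factory=...) provides: an auto-generated __init__ and __repr__, plus a fresh mutable default for every instance.

    from collections import deque
    from dataclasses import dataclass, field
    from typing import Deque

    @dataclass
    class Window:  # toy example, not part of OLMo
        size: int = 3
        values: Deque[float] = field(default_factory=deque)  # a new deque for every instance

    w = Window()   # __init__ was generated automatically
    print(w)       # __repr__ too: Window(size=3, values=deque([]))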

def batch_start(self, global_total_tokens: int, device_batch_num_tokens: int, record: bool = True) -> None:
        self.global_total_tokens = global_total_tokens
        if record:
            if len(self.start_times) >= self.cfg.window_size:
                self.start_times.popleft()
                self.device_interval_tokens.popleft()
            self.start_times.append(time.monotonic())
            self.device_interval_tokens.append(device_batch_num_tokens)

The batch_start method is called at the start of every batch. Its parameters: global_total_tokens: int, the current global total number of tokens; device_batch_num_tokens: int, the number of tokens in the current batch on this device; record: bool = True, whether to record this batch (defaults to True).

Inside batch_start: the global_total_tokens attribute is updated. If record is True, the current timestamp and the device's batch token count are appended to their respective deques; if a deque already holds cfg.window_size entries, the oldest timestamp and token count are popped from the left first, so the monitor only ever keeps a sliding window over the most recent batches.
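A minimal sketch of that sliding-window behavior, using a made-up window_size of 3 (in OLMo the value comes from cfg.window_size):

    from collections import deque

    window_size = 3          # hypothetical value for illustration
    start_times: deque = deque()

    for t in [10.0, 11.0, 12.0, 13.0, 14.0]:
        if len(start_times) >= window_size:
            start_times.popleft()   # drop the oldest timestamp
        start_times.append(t)

    print(start_times)  # deque([12.0, 13.0, 14.0]) -- only the 3 most recent batches remain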

def reset(self) -> None:
    self.start_times.clear()
    self.device_interval_tokens.clear()

The reset method clears the start_times and device_interval_tokens deques; it is used to reset the monitor in certain situations (for example, so that time spent saving checkpoints or running evaluations is not counted toward throughput). The check method:

def check(self) -> Dict[str, float]:
     metrics: Dict[str, float] = {"throughput/total_tokens": self.global_total_tokens}
     if self.start_times:
         interval_seconds = time.monotonic() - self.start_times[0]
         interval_batches = len(self.start_times)
         interval_tokens = sum(self.device_interval_tokens)
         metrics["throughput/device/tokens_per_second"] = interval_tokens / interval_seconds
         metrics["throughput/device/batches_per_second"] = interval_batches / interval_seconds
     return metrics

check computes and returns the current monitoring metrics as a dictionary with the following keys: "throughput/total_tokens", the global total number of tokens; "throughput/device/tokens_per_second", tokens processed per second on this device; "throughput/device/batches_per_second", batches processed per second on this device.
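A rough usage sketch, assuming the SpeedMonitor class above is available; the config stand-in and the token counts are made up for illustration:

    import time
    from dataclasses import dataclass

    @dataclass
    class FakeSpeedMonitorConfig:  # stand-in for the real SpeedMonitorConfig
        window_size: int = 100

    monitor = SpeedMonitor(cfg=FakeSpeedMonitorConfig())
    global_tokens = 0
    for step in range(5):
        global_tokens += 8192                 # pretend each device batch holds 8,192 tokens
        monitor.batch_start(global_tokens, 8192, record=step > 0)  # skip the first (warm-up) batch
        time.sleep(0.01)                      # pretend this is the forward/backward pass

    print(monitor.check())
    # e.g. {'throughput/total_tokens': 40960,
    #       'throughput/device/tokens_per_second': ...,
    #       'throughput/device/batches_per_second': ...}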

CLASS 2

@dataclass
class LRMonitor:
 optim: torch.optim.Optimizer

 def check(self) -> Dict[str, float]:
     lrs = [group["lr"] for group in self.optim.param_groups]
     return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}

Defines a class named LRMonitor that monitors the learning rate of every parameter group in an optimizer.

  1. optim: torch.optim.Optimizer: the PyTorch optimizer that LRMonitor watches.
  2. def check(self) -> Dict[str, float]: returns a dictionary with the learning rate of each parameter group, which makes it easy to monitor and log how the learning rates change.
  3. lrs = [group["lr"] for group in self.optim.param_groups]: reads the learning rate of every parameter group and collects them in the list lrs.
  4. return {f"optim/learning_rate_group{idx}": lr for idx, lr in enumerate(lrs)}: builds and returns a dictionary keyed by group index, pairing each learning rate with its index via a dict comprehension and enumerate.

Parameter groups: an optimizer's parameters are organized into one or more parameter groups, and each group can carry its own learning rate, weight decay, and other hyperparameters, so different parameters can follow different optimization settings during training. In PyTorch a parameter group is a dictionary that typically contains keys such as params (the parameters to optimize in this group, usually weights and biases), lr (the learning rate applied to the group), weight_decay (the extra L2 regularization applied to the group), betas (the decay factors for the momentum and squared-gradient terms in optimizers such as Adam), and so on, depending on the optimizer and its configuration.
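A minimal sketch of check with two parameter groups; the model and learning rates below are made up, and LRMonitor is assumed to be the class defined above:

    import torch

    # Hypothetical two-layer model with a different learning rate per parameter group.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    optim = torch.optim.AdamW(
        [
            {"params": model[0].parameters(), "lr": 1e-4},
            {"params": model[1].parameters(), "lr": 3e-4},
        ]
    )

    monitor = LRMonitor(optim)
    print(monitor.check())
    # {'optim/learning_rate_group0': 0.0001, 'optim/learning_rate_group1': 0.0003}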

CLASS 3:

@dataclass
class Trainer:
 cfg: TrainConfig
 model: Olmo
 fsdp_model: FSDP
 optim: Optimizer
 scheduler: Scheduler
 train_loader: DataLoader
 device: torch.device
 evaluators: List[Evaluator]
 epoch: Optional[int] = None
 global_step: int = 0
 global_train_examples_seen_this_epoch: int = 0
 """Tracks the global number of training examples seen in the current epoch for the purpose of restoring
 the data loader position on restarts."""
 global_train_tokens_seen: int = 0
 """Tracks the global total number of tokens trained on."""
 checkpoints: List[Path] = field(default_factory=list)
 unsharded_checkpoints: List[Path] = field(default_factory=list)
 ephemeral_checkpoints: List[Path] = field(default_factory=list)
 min_train_loss: float = float("inf")
 cur_train_loss: float = float("inf")
 indices_file: Optional[TextIO] = None
 _start_time: float = 0.0

 @property
 def dataset(self) -> IterableDataset:
     assert isinstance(self.train_loader.dataset, IterableDataset)
     return self.train_loader.dataset

 @property
 def tokens_per_batch(self) -> int:
     return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length

 @property
 def batches_per_epoch(self) -> int:
     return self.dataset.total_size // self.cfg.global_train_batch_size

 @property
 def max_epochs(self) -> int:
     if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
         return int(self.cfg.max_duration[:-2].strip())
     else:
         return 1

 @property
 def max_steps(self) -> int:
     if isinstance(self.cfg.max_duration, int):
         return self.cfg.max_duration
     elif isinstance(self.cfg.max_duration, str):
         if self.cfg.max_duration.endswith("T"):
             # convert to float *first* to handle scientific notation
             max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
             tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
             steps_remaining = tokens_remaining // self.tokens_per_batch
             return self.global_step + steps_remaining
         elif self.cfg.max_duration.endswith("ep"):
             max_epochs = int(self.cfg.max_duration[:-2].strip())
             return max_epochs * self.batches_per_epoch
         else:
             # convert to float *first* to handle scientific notation
             return int(float(self.cfg.max_duration))
     else:
         raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

 @property
 def max_tokens(self) -> int:
     if isinstance(self.cfg.max_duration, int):
         return (
             self.global_train_tokens_seen
             + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
         )
     elif isinstance(self.cfg.max_duration, str):
         if self.cfg.max_duration.endswith("T"):
             # convert to float *first* to handle scientific notation
             return int(float(self.cfg.max_duration[:-1].strip()))
         elif self.cfg.max_duration.endswith("ep"):
             max_epochs = int(self.cfg.max_duration[:-2].strip())
             return max_epochs * self.batches_per_epoch * self.tokens_per_batch
         else:
             # convert to float *first* to handle scientific notation
             return (
                 self.global_train_tokens_seen
                 + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
             )
     else:
         raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

 @property
 def scheduler_current(self) -> int:
     if self.cfg.scheduler.units == SchedulerUnits.steps:
         return self.global_step
     elif self.cfg.scheduler.units == SchedulerUnits.tokens:
         return self.global_train_tokens_seen
     else:
         raise NotImplementedError(self.cfg.scheduler.units)

 @property
 def scheduler_max(self) -> int:
     if self.cfg.scheduler.units == SchedulerUnits.steps:
         return self.max_steps
     elif self.cfg.scheduler.units == SchedulerUnits.tokens:
         return self.max_tokens
     else:
         raise NotImplementedError(self.cfg.scheduler.units)

 def trainer_state_dict(self) -> Dict[str, Any]:
     return {
         "epoch": self.epoch,
         "global_step": self.global_step,
         "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
         "global_train_tokens_seen": self.global_train_tokens_seen,
         "world_size": get_world_size(),
         "checkpoints": self.checkpoints,
         "unsharded_checkpoints": self.unsharded_checkpoints,
         "ephemeral_checkpoints": self.ephemeral_checkpoints,
         "rng": {
             "python": random.getstate(),
             "numpy": np.random.get_state(),
             "torch": torch.random.get_rng_state(),
             "cuda": torch.cuda.get_rng_state(),
         },
     }

 def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
     # Checkpoint paths.
     self.checkpoints = [
         path
         for path in state_dict["checkpoints"]
         if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
     ]
     self.unsharded_checkpoints = [
         path
         for path in state_dict["unsharded_checkpoints"]
         if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
     ]
     self.ephemeral_checkpoints = [
         path
         for path in state_dict.get("ephemeral_checkpoints", [])
         if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
     ]

     # Dataset / dataloader position.
     checkpoint_epoch = state_dict.get("epoch", 0)
     self.global_step = state_dict["global_step"]
     self.global_train_examples_seen_this_epoch = state_dict.get(
         "global_train_examples_seen_this_epoch",
         state_dict.get(  # for backwards compatibility
             "global_train_examples_seen",
             state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
         ),
     )
     self.global_train_tokens_seen = state_dict.get(
         "global_train_tokens_seen",
         state_dict.get("global_data_step", self.global_step)  # for backwards compatibility
         * self.cfg.global_train_batch_size
         * self.cfg.model.max_sequence_length,
     )

     if not self.cfg.restore_dataloader:
         self.epoch = 0
         self.global_train_tokens_seen = 0
         self.global_train_examples_seen_this_epoch = 0
     elif self.epoch is None:
         self.epoch = checkpoint_epoch
     elif checkpoint_epoch != self.epoch:
         log.info(f"Starting new epoch (epoch = {self.epoch})")
         self.global_train_examples_seen_this_epoch = 0

     if self.cfg.fast_forward_batches:
         log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
         # Technically we don't "see" these batches that we fast-forward through, but we use
         # this variable to update the position of the dataset so we need to include them here.
         self.global_train_examples_seen_this_epoch += (
             self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
         )
         # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
         # that variable is meant to track the actual number of tokens trained on.

     if self.global_train_examples_seen_this_epoch > 0:
         assert isinstance(self.dataset, IterableDataset)
         log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
         self.dataset.start_index = self.global_train_examples_seen_this_epoch

     # Reset learning rate and weight decay to the values from the config, not the checkpoint.
     log.info("Resetting learning rate...")
     new_learning_rate = self.scheduler.get_lr(
         self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
     )
     for group in self.optim.param_groups:
         group["lr"] = new_learning_rate
         group["initial_lr"] = self.cfg.optimizer.learning_rate
         if "weight_decay" in group and group["weight_decay"] > 0.0:
             group["weight_decay"] = self.cfg.optimizer.weight_decay

     # RNG states.
     if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
         log.info("Restoring RNG states...")
         rng_state = state_dict["rng"]
         self.restore_rng_state(rng_state)
     else:
         log.warning(
             "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
             "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
             "with a different world size. If that's the case you can safely ignore this warning."
         )

 def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
     random.setstate(rng_state["python"])
     np.random.set_state(rng_state["numpy"])
     torch.set_rng_state(rng_state["torch"])
     torch.cuda.set_rng_state(rng_state["cuda"])

 def _save_checkpoint(
     self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
 ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     if checkpoint_type == CheckpointType.sharded:
         suffix = ""
         current_checkpoints = self.checkpoints
         link_latest = get_fs_local_rank() == 0
         num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
     elif checkpoint_type == CheckpointType.unsharded:
         suffix = "-unsharded"
         current_checkpoints = self.unsharded_checkpoints
         link_latest = get_global_rank() == 0
         num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
     elif checkpoint_type == CheckpointType.sharded_ephemeral:
         suffix = ""
         current_checkpoints = self.ephemeral_checkpoints
         link_latest = get_fs_local_rank() == 0
         num_checkpoints_to_keep = 1
     else:
         raise NotImplementedError(checkpoint_type)

     # Zero-gradients to avoid gathering them.
     self.optim.zero_grad(set_to_none=True)

     # Flush data indices file.
     # TODO: upload the indices files?
     if self.indices_file is not None:
         self.indices_file.flush()

     checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
     remote_checkpoint_dir: Optional[str] = None
     if self.cfg.remote_save_folder is not None:
         remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
     current_checkpoints.append(checkpoint_dir)

     # Save the checkpoint.
     try:
         checkpointer.save_checkpoint(
             checkpoint_dir,
             self.fsdp_model,
             self.optim,
             self.trainer_state_dict(),
             upload_to=remote_checkpoint_dir,
         )
     except FileExistsError:
         raise OlmoConfigurationError(
             f"Checkpoint for step {self.global_step} already exists, use --save-overwrite to overwrite it"
         )

     if link_latest:
         # Link to 'latest'.
         latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
         latest_path.unlink(missing_ok=True)
         try:
             latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
         except FileExistsError:
             # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
             # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
             # file-systems.
             if latest_path.resolve().name != checkpoint_dir.name:
                 raise

     # Remove old checkpoints.
     if num_checkpoints_to_keep > 0:
         while len(current_checkpoints) > num_checkpoints_to_keep:
             self.remove_checkpoint(0, checkpoint_type)

     barrier()

     if remote_checkpoint_dir is not None:
         return remote_checkpoint_dir, checkpoint_dir
     else:
         return checkpoint_dir, None

 def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     checkpointer = build_sharded_checkpointer(self.cfg)
     return self._save_checkpoint(checkpointer, CheckpointType.sharded)

 def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     checkpointer = build_sharded_checkpointer(self.cfg)
     return self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)

 def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
     oldest_checkpoint = checkpoints.pop(idx)
     barrier()
     if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
         shutil.rmtree(oldest_checkpoint, ignore_errors=True)
         latest_path = Path(self.cfg.save_folder) / "latest"
         if latest_path.resolve() == oldest_checkpoint.resolve():
             latest_path.unlink()
     barrier()

 def remove_sharded_checkpoint(self, idx: int = 0):
     self._remove_sharded_checkpoint(idx, self.checkpoints)

 def remove_ephemeral_checkpoint(self, idx: int = 0):
     self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)

 def restore_sharded_checkpoint(
     self,
     load_path: PathOrStr,
     local_cache: Optional[PathOrStr] = None,
     *,
     load_optimizer_state: bool = True,
     load_trainer_state: bool = True,
     sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
 ):
     # Zero-gradients to avoid gathering them.
     self.optim.zero_grad(set_to_none=True)
     checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
     trainer_state = checkpointer.restore_checkpoint(
         load_path,
         self.fsdp_model,
         self.optim,
         local_cache=local_cache,
         load_optimizer_state=load_optimizer_state,
     )
     if load_trainer_state:
         self.load_trainer_state_dict(trainer_state)
     barrier()

 def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     checkpointer = FullCheckpointer(self.cfg)
     return self._save_checkpoint(checkpointer, CheckpointType.unsharded)

 def remove_unsharded_checkpoint(self, idx: int = 0):
     barrier()
     oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
     if get_global_rank() == 0 and oldest_checkpoint.is_dir():
         shutil.rmtree(oldest_checkpoint, ignore_errors=True)
         latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
         if latest_path.resolve() == oldest_checkpoint.resolve():
             latest_path.unlink()
     barrier()

 def restore_unsharded_checkpoint(
     self,
     load_path: PathOrStr,
     local_cache: Optional[PathOrStr] = None,
     *,
     load_optimizer_state: bool = True,
     load_trainer_state: bool = True,
 ):
     # Zero-gradients to avoid gathering them.
     self.optim.zero_grad(set_to_none=True)
     checkpointer = FullCheckpointer(self.cfg)
     trainer_state = checkpointer.restore_checkpoint(
         load_path,
         self.fsdp_model,
         self.optim,
         local_cache=local_cache,
         load_optimizer_state=load_optimizer_state,
     )
     if load_trainer_state:
         self.load_trainer_state_dict(trainer_state)
     barrier()

 def save_checkpoint(
     self, checkpoint_type: CheckpointType = CheckpointType.sharded
 ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     if checkpoint_type == CheckpointType.sharded:
         return self.save_sharded_checkpoint()
     elif checkpoint_type == CheckpointType.unsharded:
         return self.save_unsharded_checkpoint()
     elif checkpoint_type == CheckpointType.sharded_ephemeral:
         return self.save_ephemeral_checkpoint()
     else:
         raise NotImplementedError(checkpoint_type)

 def restore_checkpoint(
     self,
     load_path: PathOrStr,
     *,
     checkpoint_type: Optional[CheckpointType] = None,
     local_cache: Optional[PathOrStr] = None,
     load_optimizer_state: bool = True,
     load_trainer_state: bool = True,
     sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
 ):
     if checkpoint_type == CheckpointType.unsharded or (
         checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
     ):
         self.restore_unsharded_checkpoint(
             load_path,
             local_cache=local_cache,
             load_optimizer_state=load_optimizer_state,
             load_trainer_state=load_trainer_state,
         )
     elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
         self.restore_sharded_checkpoint(
             load_path,
             local_cache=local_cache,
             load_optimizer_state=load_optimizer_state,
             load_trainer_state=load_trainer_state,
             sharded_checkpointer=sharded_checkpointer,
         )
     elif checkpoint_type is not None:
         raise NotImplementedError(checkpoint_type)

 def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
     if checkpoint_type == CheckpointType.sharded:
         self.remove_sharded_checkpoint(idx=idx)
     elif checkpoint_type == CheckpointType.unsharded:
         self.remove_unsharded_checkpoint(idx=idx)
     elif checkpoint_type == CheckpointType.sharded_ephemeral:
         self.remove_ephemeral_checkpoint(idx=idx)
     else:
         raise NotImplementedError(checkpoint_type)

 def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
     # Labels are just input IDs shifted to the left (first item is ignored).
     labels, label_mask, attention_mask = (
         batch["input_ids"].clone(),
         batch.get("label_mask"),
         batch.get("attention_mask"),
     )
     if label_mask is not None:
         labels.masked_fill_(~label_mask, -100)
     if attention_mask is not None:
         labels.masked_fill_(attention_mask == 0.0, -100)
     return labels[..., 1:].contiguous()

 def model_forward(
     self, batch: Dict[str, Any], loss_reduction: str = "mean"
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # shape: (batch_size, seq_len, vocab_size)
     logits = self.fsdp_model(
         input_ids=batch["input_ids"],
         attention_mask=batch.get("attention_mask"),
         attention_bias=batch.get("attention_bias"),
     ).logits
     logits_for_loss = logits[..., :-1, :].contiguous()
     # shape: (batch_size * seq_len, vocab_size)
     logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
     # shape: (batch_size, seq_len)
     labels = self.get_labels(batch)
     # shape: (batch_size * seq_len,)
     labels = labels.view(-1)
     ce_loss = F.cross_entropy(logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction)
     if loss_reduction == "none":
         # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
         ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
     return ce_loss, logits

 def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     # Split into micro-batches.
     micro_batches = self.split_batch(batch)

     # In case this helps with memory utilization.
     del batch

     ce_batch_loss = torch.tensor(0.0, device=self.device)
     z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
     for micro_batch in micro_batches:
         with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
             # Run forward pass.
             ce_loss, logits = self.model_forward(micro_batch)
             ce_loss = ce_loss / len(micro_batches)

             # In case this helps with memory utilization.
             del micro_batch

             # Update overall CE batch loss.
             ce_batch_loss += ce_loss.detach()

             # Get loss to optimize for.
             if self.cfg.softmax_auxiliary_loss:
                 z_squared = logits.logsumexp(-1).pow(2).mean()
                 z_loss = 1e-4 * z_squared / len(micro_batches)
                 loss = ce_loss + z_loss

                 # Update overall Z batch loss.
                 z_batch_loss += z_loss.detach()
             else:
                 loss = ce_loss

             del logits

         # Run backward pass.
         loss.backward()

     return ce_batch_loss, z_batch_loss

 def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
     metrics: Dict[str, float] = {}

     # Write data-indices to file.
     if self.indices_file is not None and "index" in batch:
         indices = "\t".join(str(int(i)) for i in batch["index"])
         self.indices_file.write(f"{self.global_step}\t{indices}\n")

     # Zero-gradients.
     self.optim.zero_grad(set_to_none=True)

     # Move tensors to the right device.
     batch = move_to_device(batch, self.device)

     # Run forward-backward pass.
     ce_batch_loss, z_batch_loss = self.train_batch(batch)

     # Collect loss, potentially reducing over all ranks.
     if reduce_global_loss:
         dist.reduce(ce_batch_loss, 0)
         ce_batch_loss.div_(get_world_size())
         if z_batch_loss is not None:
             dist.reduce(z_batch_loss, 0)
             z_batch_loss.div_(get_world_size())

     # Clip gradient norms and collect param/gradient/optim metrics.
     should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
     optim_metrics = self.optim.clip_grads_and_collect_metrics(
         self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
     )

     # Adjust the learning rate.
     for group in self.optim.param_groups:
         # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
         # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
         # the corresponding values from `self.cfg`.
         group["lr"] = self.scheduler.get_lr(
             self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
         )
         group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
             self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
         )
         group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
             self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
         )

     # Optimizer step.
     self.optim.step()

     # Collect metrics and check for NaN loss.
     # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
     if torch.isnan(ce_batch_loss):
         raise ValueError("nan loss encountered")
     if z_batch_loss is not None and torch.isnan(z_batch_loss):
         raise ValueError("nan loss encountered")
     for key, value in optim_metrics.items():
         metrics[f"optim/{key}"] = value.item()
     self.cur_train_loss = ce_batch_loss.item()
     self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
     metrics["train/CrossEntropyLoss"] = self.cur_train_loss
     metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
     if z_batch_loss is not None:
         metrics["train/ZLoss"] = z_batch_loss.item()

     # Maybe collect post-step optimizer-specific metrics.
     if should_log_optim_metrics_this_step:
         optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
         for key, value in optim_metrics.items():
             metrics[f"optim/{key}"] = value.item()

     return metrics

 def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
     with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
         ce_loss, logits = self.model_forward(batch, loss_reduction="none")
     return ce_loss.mean(dim=-1), logits

 def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
     # Move tensors to the right device.
     batch = move_to_device(batch, self.device)

     # Run forward pass.
     with torch.no_grad():  # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
         ce_loss, logits = self.eval_batch(batch)

     # Update metrics.
     evaluator.update_metrics(
         batch, ce_loss, logits
     )  # batch includes all keys that the downstream evaluation needs

     barrier()

 def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
     microbatch_size = self.cfg.device_train_microbatch_size
     batch_size = batch["input_ids"].shape[0]
     if batch_size <= microbatch_size:
         return [batch]
     else:
         micro_batches = {}
         for key, value in batch.items():
             if isinstance(value, torch.Tensor):
                 micro_batches[key] = value.split(microbatch_size, dim=0)
             elif isinstance(value, list):
                 micro_batches[key] = [
                     value[microbatch_size * i : microbatch_size * i + microbatch_size]
                     for i in range(math.ceil(batch_size / microbatch_size))
                 ]
             else:
                 raise ValueError(f"unexpected item in batch: '{key}={value}'")
         return [
             {key: value[i] for key, value in micro_batches.items()}  # type: ignore
             for i in range(len(micro_batches["input_ids"]))
         ]

 def system_metrics(self) -> Dict[str, float]:
     metrics = {}
     if self.global_step < 3 or self.global_step % 10 == 0:
         peak_gpu_mb = peak_gpu_memory()
         if peak_gpu_mb is not None:
             metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
     return metrics

 def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
     def format_float(value: float) -> str:
         if value < 0.0001:
             return str(value)  # scientific notation
         elif value > 1000:
             return f"{int(value):,d}"
         elif value > 100:
             return f"{value:.1f}"
         elif value > 10:
             return f"{value:.2f}"
         elif value > 1:
             return f"{value:.3f}"
         else:
             return f"{value:.4f}"

     log.info(
         f"{prefix}\n"
         + "\n".join(
             [
                 f"    {name}={format_float(value)}"
                 for name, value in metrics.items()
                 if not name.startswith("optim/")  # there's too many optimizer metrics
             ]
         )
     )

 def should_log_optim_metrics_this_step(self) -> bool:
     if self.cfg.wandb is None:
         # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
         # to log to the console.
         return False
     optim_log_interval = self.cfg.optimizer.metrics_log_interval
     if optim_log_interval is None:
         optim_log_interval = self.cfg.wandb.log_interval
     else:
         optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
     return self.global_step % optim_log_interval == 0

 def should_log_this_step(self) -> bool:
     if self.global_step % self.cfg.console_log_interval == 0:
         return True
     elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
         return True
     else:
         return False

 def eval(self) -> Dict[str, Any]:
     # Zero gradients and set model to 'eval' mode.
     self.optim.zero_grad(set_to_none=True)
     self.fsdp_model.eval()

     eval_metrics = {}
     for evaluator in self.evaluators:
         log.info(f"Running evaluation for '{evaluator.label}'...")

         # Reset metrics.
         evaluator.reset_metrics()

         # Initialize data loader iterator.
         eval_batches = iter(evaluator.eval_loader)

         # Adjust how many batches to evaluate on.
         num_eval_batches = (
             evaluator.subset_num_batches
             if evaluator.subset_num_batches is not None
             else self.cfg.eval_subset_num_batches
         )
         if num_eval_batches > 0:
             num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
             eval_batches = islice(eval_batches, num_eval_batches)

         # Run model over batches.
         for eval_step, eval_batch in enumerate(eval_batches):
             self.eval_step(eval_batch, evaluator)

             # Log to console.
             if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
                 log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")

         # Get final metrics.
         metrics = evaluator.compute_metrics()
         eval_metrics.update(metrics)
         self.log_metrics_to_console(f"{evaluator.label}", metrics)

         del eval_batches

     return eval_metrics

 def check_if_cancelled(self) -> Tuple[bool, int]:
     should_cancel = False
     cancel_reason: Optional[str] = None
     extra_steps = 0
     if get_global_rank() == 0:
         if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
             # First check if we've reached the training time limit.
             should_cancel = True
             cancel_reason = "time limit reached"
             extra_steps = self.cfg.extra_steps_after_cancel
         elif (
             self.cfg.early_stopping_factor is not None
             and self.global_step > self.cfg.scheduler.t_warmup
             and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
         ):
             # Next check if early stopping loss criteria is met.
             should_cancel = True
             cancel_reason = "early stopping from loss increase"
         elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
             # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag..
             # We won't see it in the run object. So we have to use the import/export API to check.
             from requests.exceptions import RequestException

             try:
                 api = wandb.Api(api_key=api_key)
                 run = api.run(wandb.run.path)
                 for tag in run.tags or []:
                     if tag.lower() in {"cancel", "canceled", "cancelled"}:
                         should_cancel = True
                         cancel_reason = "Weights & Biases tag"
                         extra_steps = self.cfg.extra_steps_after_cancel
                         break
             except RequestException:
                 pass

     run_canceled = synchronize_flag(should_cancel, self.device)
     if run_canceled and cancel_reason is not None:
         extra_steps = synchronize_value(extra_steps, self.device)
         if extra_steps > 0:
             log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
         else:
             log.warning(f"Run canceled due to {cancel_reason}")

     return run_canceled, extra_steps

 def fit(self):
     self._start_time = time.time()

     if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
         eval_metrics = self.eval()
         if wandb.run is not None:
             wandb.log(eval_metrics, step=self.global_step)

     # Set model to 'train' mode.
     self.fsdp_model.train()

     # Initialize monitors.
     assert self.cfg.device_train_batch_size is not None
     speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
     lr_monitor = LRMonitor(self.optim)

     # Log system metrics at the start of training.
     sys_metrics = self.system_metrics()
     if sys_metrics:
         self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
         if wandb.run is not None:
             wandb.log(sys_metrics, step=0)

     # Python Profiler stuff
     if self.cfg.python_profiling:
         python_profiler = cProfile.Profile()
     else:
         python_profiler = None

     # PyTorch Profiler stuff
     if self.cfg.torch_profiling and get_global_rank() == 0:
         from torch.profiler import schedule

         profiling_schedule = schedule(wait=1, warmup=5, active=3)

         def on_trace_ready(p):
             profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
             profiler_output_dir.mkdir(exist_ok=True)

             output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
             log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
             output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
             log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")

             p.export_chrome_trace(
                 str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
             )
             if self.cfg.remote_save_folder is not None:
                 upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
                 log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
                 upload(trace_path, f"{upload_folder}/{trace_path.name}")

         from torch.profiler import ProfilerActivity

         torch_profiler = torch.profiler.profile(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=False,
             profile_memory=False,
             with_stack=True,
             schedule=profiling_schedule,
             on_trace_ready=on_trace_ready,
         )
         del profiling_schedule
     else:
         import contextlib

         torch_profiler = contextlib.nullcontext()

     # Train.
     first_batch: bool = True
     cancel_initiated: bool = False
     stop_at: Optional[int] = self.cfg.stop_at
     save_checkpoints: bool = True

     with torch_profiler as p:
         for epoch in range(self.epoch or 0, self.max_epochs):
             for batch in self.train_loader:
                 # Bookkeeping.
                 # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
                 # batches see the same number of tokens, which should be the case for language model pre-training
                 # (at least when drop_last=True).
                 # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
                 # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
                 # fail loudly.
                 batch_size, seq_len = batch["input_ids"].shape
                 assert seq_len == self.cfg.model.max_sequence_length
                 assert batch_size == self.cfg.device_train_batch_size
                 global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
                 self.global_step += 1
                 self.global_train_examples_seen_this_epoch += global_batch_size
                 self.global_train_tokens_seen += global_batch_size * seq_len
                 speed_monitor.batch_start(
                     self.global_train_tokens_seen,
                     batch_size * seq_len,  # num tokens in batch for this device
                     # We start monitoring speed after the first batch since the first
                     # batch might be an outlier due to compiling and other initialization overhead.
                     record=not first_batch,
                 )

                 should_log_this_step = self.should_log_this_step()

                 # Run train step on batch.
                 metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

                 # Maybe collect other metrics.
                 if should_log_this_step:
                     # Speed metrics.
                     metrics.update(speed_monitor.check())
                     # System metrics.
                     metrics.update(self.system_metrics())
                     # Learning rate metrics.
                     metrics.update(lr_monitor.check())

                 # Log metrics to console.
                 if self.global_step % self.cfg.console_log_interval == 0:
                     self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

                 # Log metrics to W&B.
                 if (
                     wandb.run is not None
                     and self.cfg.wandb is not None
                     and self.global_step % self.cfg.wandb.log_interval == 0
                 ):
                     wandb.log(metrics, step=self.global_step)

                 # Check if/when run should be canceled.
                 if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                     cancel_initiated, extra_steps = self.check_if_cancelled()
                     if cancel_initiated:
                         stop_at = (
                             self.global_step + extra_steps
                             if stop_at is None
                             else min(self.global_step + extra_steps, stop_at)
                         )

                 # Maybe save sharded checkpoint.
                 if save_checkpoints and (
                     cancel_initiated
                     or (
                         self.global_step % self.cfg.save_interval == 0
                         and self.cfg.save_num_checkpoints_to_keep != 0
                     )
                 ):
                     log.info("Saving checkpoint...")
                     checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                     log.info(f"Checkpoint saved to {checkpoint_path}")

                     # Remove any ephemeral checkpoints.
                     while self.ephemeral_checkpoints:
                         self.remove_ephemeral_checkpoint()

                     # Reset speed monitor so that we don't count the time taken to save checkpoints.
                     speed_monitor.reset()

                     # If the run was just canceled this will be the final checkpoint.
                     if cancel_initiated:
                         save_checkpoints = False
                 elif (
                     self.cfg.save_interval_ephemeral is not None
                     and self.global_step % self.cfg.save_interval_ephemeral == 0
                 ):
                     log.info("Saving ephemeral checkpoint...")
                     checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                     log.info(f"Checkpoint saved to {checkpoint_path}")

                     # Reset speed monitor so that we don't count the time taken to save checkpoints.
                     speed_monitor.reset()

                 # Maybe save unsharded checkpoint.
                 if (
                     save_checkpoints
                     and self.cfg.save_interval_unsharded is not None
                     and self.global_step % self.cfg.save_interval_unsharded == 0
                     and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
                 ):
                     log.info("Saving unsharded checkpoint...")
                     checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                     log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                     # Reset speed monitor so that we don't count the time taken to save checkpoints.
                     speed_monitor.reset()

                 # Maybe run evaluations.
                 if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                     eval_metrics = self.eval()

                     # Log metrics to W&B.
                     if wandb.run is not None:
                         wandb.log(eval_metrics, step=self.global_step)

                     # Reset speed monitor so that we don't count the time taken to run evaluations.
                     speed_monitor.reset()

                     # Reset model to 'train' mode.
                     self.fsdp_model.train()

                 # End of batch.
                 first_batch = False
                 if p is not None:
                     p.step()

                 if stop_at is not None and self.global_step >= stop_at:
                     break

                 # Python Profiler stuff
                 # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                 if python_profiler is not None:
                     if self.global_step == 5:
                         python_profiler.enable()
                     elif self.global_step == 8:
                         python_profiler.disable()
                         python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                         python_profiler = None
             else:
                 log.info("Training epoch complete")
                 self.epoch = epoch + 1
                 self.global_train_examples_seen_this_epoch = 0
                 if self.epoch < self.max_epochs:
                     self.dataset.reshuffle()
                 continue

             break

     # Save final checkpoint.
     if save_checkpoints:
         if self.cfg.save_interval_unsharded is not None:
             log.info("Saving final unsharded model checkpoint...")
             checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
             log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
         elif self.cfg.save_num_checkpoints_to_keep != 0:
             log.info("Saving final checkpoint...")
             checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
             log.info(f"Checkpoint saved to {checkpoint_path}")

 def close(self, exit_code: int = 0) -> None:
     if self.indices_file is not None:
         self.indices_file.flush()
         self.indices_file.close()
     if wandb.run is not None:
         wandb.finish(exit_code=exit_code, quiet=True)

 def __enter__(self) -> Trainer:
     return self

 def __exit__(self, exc_type, exc_val, exc_tb) -> None:
     del exc_val, exc_tb
     self.close(0 if exc_type is None else 1)
1. Defining the trainer

    @dataclass
    class Trainer:
    cfg: TrainConfig
    model: Olmo
    fsdp_model: FSDP
    optim: Optimizer
    scheduler: Scheduler
    train_loader: DataLoader
    device: torch.device
    evaluators: List[Evaluator]
    epoch: Optional[int] = None
    global_step: int = 0
    global_train_examples_seen_this_epoch: int = 0
    global_train_tokens_seen: int = 0
    checkpoints: List[Path] = field(default_factory=list)
    unsharded_checkpoints: List[Path] = field(default_factory=list)
    ephemeral_checkpoints: List[Path] = field(default_factory=list)
    min_train_loss: float = float("inf")
    cur_train_loss: float = float("inf")
    indices_file: Optional[TextIO] = None
    _start_time: float = 0.0

The class attributes are:

  • cfg: TrainConfig: the training configuration object.
  • model: Olmo: the model being trained.
  • fsdp_model: FSDP: the FSDP-wrapped version of the model.
  • optim: Optimizer: the optimizer.
  • scheduler: Scheduler: the learning-rate scheduler.
  • train_loader: DataLoader: the training data loader.
  • device: torch.device: the device to train on.
  • evaluators: List[Evaluator]: the list of evaluators.
  • epoch: Optional[int] = None: the current training epoch; defaults to None.
  • global_step: int = 0: tracks the global step count; starts at 0.
  • global_train_examples_seen_this_epoch: int = 0: tracks how many training examples have been seen in the current epoch (used to restore the data loader position on restarts).
  • global_train_tokens_seen: int = 0: tracks the global number of tokens trained on.
  • checkpoints: List[Path] = field(default_factory=list): the list of saved (sharded) checkpoint paths.
  • unsharded_checkpoints: List[Path] = field(default_factory=list): the list of saved unsharded checkpoint paths.
  • ephemeral_checkpoints: List[Path] = field(default_factory=list): the list of ephemeral checkpoint paths.
  • min_train_loss: float = float("inf"): the minimum training loss seen so far; defaults to positive infinity.
  • cur_train_loss: float = float("inf"): the current training loss; defaults to positive infinity.
  • indices_file: Optional[TextIO] = None: a text I/O handle for writing data indices; may be None.
  • _start_time: float = 0.0: the timestamp at which training started; initialized to 0.0.

2. Basic properties

@property
    def dataset(self) -> IterableDataset:
        assert isinstance(self.train_loader.dataset, IterableDataset)
        return self.train_loader.dataset

    @property
    def tokens_per_batch(self) -> int:
        return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length

    @property
    def batches_per_epoch(self) -> int:
        return self.dataset.total_size // self.cfg.global_train_batch_size

    @property
    def max_epochs(self) -> int:
        if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
            return int(self.cfg.max_duration[:-2].strip())
        else:
            return 1

    @property
    def max_steps(self) -> int:
        if isinstance(self.cfg.max_duration, int):
            return self.cfg.max_duration
        elif isinstance(self.cfg.max_duration, str):
            if self.cfg.max_duration.endswith("T"):
                # convert to float *first* to handle scientific notation
                max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
                tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
                steps_remaining = tokens_remaining // self.tokens_per_batch
                return self.global_step + steps_remaining
            elif self.cfg.max_duration.endswith("ep"):
                max_epochs = int(self.cfg.max_duration[:-2].strip())
                return max_epochs * self.batches_per_epoch
            else:
                # convert to float *first* to handle scientific notation
                return int(float(self.cfg.max_duration))
        else:
            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

    @property
    def max_tokens(self) -> int:
        if isinstance(self.cfg.max_duration, int):
            return (
                self.global_train_tokens_seen
                + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
            )
        elif isinstance(self.cfg.max_duration, str):
            if self.cfg.max_duration.endswith("T"):
                # convert to float *first* to handle scientific notation
                return int(float(self.cfg.max_duration[:-1].strip()))
            elif self.cfg.max_duration.endswith("ep"):
                max_epochs = int(self.cfg.max_duration[:-2].strip())
                return max_epochs * self.batches_per_epoch * self.tokens_per_batch
            else:
                # convert to float *first* to handle scientific notation
                return (
                    self.global_train_tokens_seen
                    + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
                )
        else:
            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

    @property
    def scheduler_current(self) -> int:
        if self.cfg.scheduler.units == SchedulerUnits.steps:
            return self.global_step
        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
            return self.global_train_tokens_seen
        else:
            raise NotImplementedError(self.cfg.scheduler.units)

    @property
    def scheduler_max(self) -> int:
        if self.cfg.scheduler.units == SchedulerUnits.steps:
            return self.max_steps
        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
            return self.max_tokens
        else:
            raise NotImplementedError(self.cfg.scheduler.units)

These properties expose: the dataset used for training; the total number of tokens per batch, computed from the config; the number of batches per epoch; the maximum number of epochs; the maximum number of training steps; the maximum number of training tokens; the scheduler's current position; and the scheduler's maximum position.

Intuition:

The dataset part is straightforward: it is simply the data used for training.

A batch splits the data into groups for training. Think of the whole dataset as an apple: based on a slice size (the batch size) it gets cut into n batches (n = S_apple / S_batch), and one epoch corresponds to eating the entire apple. A step corresponds to eating one batch with whatever appetite you have; every batch you finish advances the step counter by 1, and the number of steps you can take is limited. Progress can be measured in several ways: how many apples you have eaten, how much volume of apple, or how many bites. When a fixed step count is reached you can take some action, such as adjusting your posture (the learning rate, and so on), and once you reach your maximum capacity (the step limit computed from that capacity) you stop. Likewise, when you choose when to adjust something like your posture, you can base the decision either on how many bites you have taken or on how much apple you have actually eaten.

The number of tokens per batch is the configured maximum sequence length multiplied by the batch size.

How many batches are there in each epoch?

@property
def dataset(self) -> IterableDataset:
    assert isinstance(self.train_loader.dataset, IterableDataset)
    return self.train_loader.dataset

  • Purpose: get the training dataset.
  • Details: returns the dataset attribute of train_loader and asserts that it is an IterableDataset.

@property
def tokens_per_batch(self) -> int:
    return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length

  • Purpose: get the total number of tokens per batch.
  • Details: the product of the global batch size from the training config (cfg) and the model's maximum sequence length.

@property
def batches_per_epoch(self) -> int:
    return self.dataset.total_size // self.cfg.global_train_batch_size

  • Purpose: get the number of batches in each epoch.
  • Details: the total dataset size divided (integer division) by the global batch size.

@property
def max_epochs(self) -> int:
    if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
        return int(self.cfg.max_duration[:-2].strip())
    else:
        return 1

  • Purpose: get the maximum number of training epochs.
  • Details: if max_duration is a string ending in "ep", parse the integer in front of that suffix; otherwise default to 1.
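For instance, with a hypothetical max_duration of "10ep" the parsing works like this:

    max_duration = "10ep"                         # hypothetical config value
    max_epochs = int(max_duration[:-2].strip())   # -> 10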

@property
def max_steps(self) -> int:

  • Purpose: get the maximum number of training steps.
  • Details: computed by handling the different forms cfg.max_duration can take (an integer step count, a token count ending in "T", or an epoch count ending in "ep").

@property
def max_tokens(self) -> int:

  • Purpose: get the maximum number of training tokens.
  • Details: likewise derived from cfg.max_duration, depending on its type and suffix.

@property
def scheduler_current(self) -> int:

  • Purpose: get the scheduler's current position.
  • Details: returns either the current step count or the current token count, depending on the scheduler units.

@property
def scheduler_max(self) -> int:
    # ... (explained below)

  • Purpose: get the scheduler's maximum position.
  • Details: returns either max_steps or max_tokens, depending on the scheduler units.
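A worked example of how these properties relate, using made-up numbers rather than OLMo's actual configuration:

    # Hypothetical config values, chosen only for illustration.
    global_train_batch_size = 2048        # sequences per global batch
    max_sequence_length = 2048            # tokens per sequence
    dataset_total_size = 1_000_000_000    # sequences in the training set
    max_duration = "2e12T"                # "train on 2 trillion tokens"

    tokens_per_batch = global_train_batch_size * max_sequence_length    # 4,194,304
    batches_per_epoch = dataset_total_size // global_train_batch_size   # 488,281
    max_tokens = int(float(max_duration[:-1]))                          # 2,000,000,000,000
    max_steps = max_tokens // tokens_per_batch                          # 476,837 (when starting from step 0)

    print(tokens_per_batch, batches_per_epoch, max_steps)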

3. Saving and loading trainer state

 def trainer_state_dict(self) -> Dict[str, Any]:
        return {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
            "global_train_tokens_seen": self.global_train_tokens_seen,
            "world_size": get_world_size(),
            "checkpoints": self.checkpoints,
            "unsharded_checkpoints": self.unsharded_checkpoints,
            "ephemeral_checkpoints": self.ephemeral_checkpoints,
            "rng": {
                "python": random.getstate(),
                "numpy": np.random.get_state(),
                "torch": torch.random.get_rng_state(),
                "cuda": torch.cuda.get_rng_state(),
            },
        }

    def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Checkpoint paths.
        self.checkpoints = [
            path
            for path in state_dict["checkpoints"]
            if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
        ]
        self.unsharded_checkpoints = [
            path
            for path in state_dict["unsharded_checkpoints"]
            if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
        ]
        self.ephemeral_checkpoints = [
            path
            for path in state_dict.get("ephemeral_checkpoints", [])
            if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
        ]

        # Dataset / dataloader position.
        checkpoint_epoch = state_dict.get("epoch", 0)
        self.global_step = state_dict["global_step"]
        self.global_train_examples_seen_this_epoch = state_dict.get(
            "global_train_examples_seen_this_epoch",
            state_dict.get(  # for backwards compatibility
                "global_train_examples_seen",
                state_dict.get("global_data_step", self.global_step) * self.cfg.global_train_batch_size,
            ),
        )
        self.global_train_tokens_seen = state_dict.get(
            "global_train_tokens_seen",
            state_dict.get("global_data_step", self.global_step)  # for backwards compatibility
            * self.cfg.global_train_batch_size
            * self.cfg.model.max_sequence_length,
        )

        if not self.cfg.restore_dataloader:
            self.epoch = 0
            self.global_train_tokens_seen = 0
            self.global_train_examples_seen_this_epoch = 0
        elif self.epoch is None:
            self.epoch = checkpoint_epoch
        elif checkpoint_epoch != self.epoch:
            log.info(f"Starting new epoch (epoch = {self.epoch})")
            self.global_train_examples_seen_this_epoch = 0

        if self.cfg.fast_forward_batches:
            log.info(f"Fast-forwarding data loader by {self.cfg.fast_forward_batches:,d} steps")
            # Technically we don't "see" these batches that we fast-forward through, but we use
            # this variable to update the position of the dataset so we need to include them here.
            self.global_train_examples_seen_this_epoch += (
                self.cfg.fast_forward_batches * self.cfg.global_train_batch_size
            )
            # NOTE: on the other hand we don't add anything to 'self.global_train_tokens_seen' here because
            # that variable is meant to track the actual number of tokens trained on.

        if self.global_train_examples_seen_this_epoch > 0:
            assert isinstance(self.dataset, IterableDataset)
            log.info(f"Data loader will start at instance index {self.global_train_examples_seen_this_epoch:,d}")
            self.dataset.start_index = self.global_train_examples_seen_this_epoch

        # Reset learning rate and weight decay to the values from the config, not the checkpoint.
        log.info("Resetting learning rate...")
        new_learning_rate = self.scheduler.get_lr(
            self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
        )
        for group in self.optim.param_groups:
            group["lr"] = new_learning_rate
            group["initial_lr"] = self.cfg.optimizer.learning_rate
            if "weight_decay" in group and group["weight_decay"] > 0.0:
                group["weight_decay"] = self.cfg.optimizer.weight_decay

        # RNG states.
        if "rng" in state_dict and state_dict.get("world_size", get_world_size()) == get_world_size():
            log.info("Restoring RNG states...")
            rng_state = state_dict["rng"]
            self.restore_rng_state(rng_state)
        else:
            log.warning(
                "Trainer will not restore RNG states since the RNG states in the checkpoint are missing or invalid. "
                "This typically happens when restoring from an unsharded checkpoint or a checkpoint that was saved "
                "with a different world size. If that's the case you can safely ignore this warning."
            )

    def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
        random.setstate(rng_state["python"])
        np.random.set_state(rng_state["numpy"])
        torch.set_rng_state(rng_state["torch"])
        torch.cuda.set_rng_state(rng_state["cuda"])

A quick recap of terminology: tokens make up an example, n examples make up a batch (n is the batch size), a dataset is split into m batches, and one full pass over all of them is one epoch. The code above captures exactly this kind of training state.

trainer_state_dict: builds the trainer's state dictionary.

  • Records the trainer's current state: the current epoch, global step, number of training examples seen this epoch, total training tokens seen, and so on.
  • Records the list of paths of the (sharded) checkpoints currently tracked by the trainer.
  • Records the list of paths of the unsharded checkpoints.
  • Records the list of paths of the ephemeral checkpoints.
  • Records the states of the random number generators (RNGs): Python's built-in RNG, NumPy's, PyTorch's CPU RNG, and the CUDA RNG.

load_trainer_state_dict: loads a state dictionary.

  • Restores the trainer's state from the dictionary: checkpoint paths, the dataset/dataloader position, the learning rate and weight decay, and the RNG states.
  • For the checkpoint paths, it rebuilds the trainer's checkpoint lists from the stored paths, keeping only directories that still exist under the configured save folder.
  • For the dataset/dataloader position, it restores the epoch, the global step, and the number of training examples seen so far.
  • Fast-forwarding is typically used right after resuming to skip a number of data-loader steps; how many steps to skip comes from the config (fast_forward_batches), and whether to use it depends on the specific run. Note that the skipped batches only advance the dataset position, not global_train_tokens_seen, which tracks tokens actually trained on.
  • The learning rate and weight decay are reset to the values from the config rather than read from the checkpoint.
  • The RNG states (Python, NumPy, PyTorch, CUDA) are restored, but only when the checkpoint was saved with the same world size.

Together these make it possible to save the trainer's state during training and load it later to resume.

restore_rng_state: restores the random number generators from the saved data by setting each generator's state directly.
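
Saving and restoring RNG state is just a matter of capturing and re-applying each library's generator state. A small self-contained demonstration, structured like the "rng" entry of trainer_state_dict (the CUDA state is left out so it runs on CPU-only machines):

    import random
    import numpy as np
    import torch

    # Capture the generator states.
    rng_state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
    }

    before = torch.rand(3)          # advances the torch RNG

    # Restoring the saved states makes the next draws repeat exactly.
    random.setstate(rng_state["python"])
    np.random.set_state(rng_state["numpy"])
    torch.set_rng_state(rng_state["torch"])
    after = torch.rand(3)

    assert torch.equal(before, after)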

4. Saving checkpoints

 def _save_checkpoint(
     self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
 ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
     if checkpoint_type == CheckpointType.sharded:
         suffix = ""
         current_checkpoints = self.checkpoints
         link_latest = get_fs_local_rank() == 0
         num_checkpoints_to_keep = self.cfg.save_num_checkpoints_to_keep
     elif checkpoint_type == CheckpointType.unsharded:
         suffix = "-unsharded"
         current_checkpoints = self.unsharded_checkpoints
         link_latest = get_global_rank() == 0
         num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
     elif checkpoint_type == CheckpointType.sharded_ephemeral:
         suffix = ""
         current_checkpoints = self.ephemeral_checkpoints
         link_latest = get_fs_local_rank() == 0
         num_checkpoints_to_keep = 1
     else:
         raise NotImplementedError(checkpoint_type)

     # Zero-gradients to avoid gathering them.
     self.optim.zero_grad(set_to_none=True)

     # Flush data indices file.
     # TODO: upload the indices files?
     if self.indices_file is not None:
         self.indices_file.flush()

     checkpoint_dir = Path(self.cfg.save_folder) / f"step{self.global_step}{suffix}"
     remote_checkpoint_dir: Optional[str] = None
     if self.cfg.remote_save_folder is not None:
         remote_checkpoint_dir = f"{self.cfg.remote_save_folder.rstrip('/')}/{checkpoint_dir.name}"
     current_checkpoints.append(checkpoint_dir)

     # Save the checkpoint.
     try:
         checkpointer.save_checkpoint(
             checkpoint_dir,
             self.fsdp_model,
             self.optim,
             self.trainer_state_dict(),
             upload_to=remote_checkpoint_dir,
         )
     except FileExistsError:
         raise OlmoConfigurationError(
             f"Checkpoint for step {self.global_step} already exists, use --save-overwrite to overwrite it"
         )

     if link_latest:
         # Link to 'latest'.
         latest_path = Path(self.cfg.save_folder) / f"latest{suffix}"
         latest_path.unlink(missing_ok=True)
         try:
             latest_path.symlink_to(checkpoint_dir.name, target_is_directory=True)
         except FileExistsError:
             # Same as above, caught when another (file-system) local rank 0 has already made the 'latest' symlink.
             # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
             # file-systems.
             if latest_path.resolve().name != checkpoint_dir.name:
                 raise

     # Remove old checkpoints.
     if num_checkpoints_to_keep > 0:
         while len(current_checkpoints) > num_checkpoints_to_keep:
             self.remove_checkpoint(0, checkpoint_type)

     barrier()

     if remote_checkpoint_dir is not None:
         return remote_checkpoint_dir, checkpoint_dir
     else:
         return checkpoint_dir, None

_save_checkpoint: saves a training checkpoint.

  1. Zero the gradients: self.optim.zero_grad(set_to_none=True) is called so gradients don't have to be gathered while saving.
  2. Flush the data-indices file: if an indices file is open, self.indices_file.flush() writes out any buffered data.
  3. Build the local checkpoint directory path and, if configured, the remote checkpoint directory path.
  4. Save the checkpoint: checkpointer.save_checkpoint() writes it out; if a remote save folder is configured, the checkpoint is also uploaded there.
  5. Link to the latest checkpoint: on the appropriate rank, a symbolic link is created pointing at the checkpoint directory that was just saved.
  6. Remove old checkpoints: checkpoints beyond the configured number to keep are deleted.
  7. Return the saved checkpoint path (and the local path as a second value when a remote path is returned first).

Checkpoint saving and restoring are the key pieces that make the training process durable and resumable.

get_fs_local_rank() and get_global_rank() return the current process's local rank and global rank in distributed training.

  1. get_fs_local_rank(): def get_fs_local_rank() -> int: return get_rank(group="fsdp"). This returns the process's rank within the process group named "fsdp". In distributed training each GPU on a node gets a local rank that marks its position on that node; a node with two GPUs, for example, has local ranks 0 and 1.

  2. get_global_rank(): def get_global_rank() -> int: return get_rank(). This returns the process's rank within the global process group. The global rank identifies the process's position in the whole distributed job; every process has a unique global rank, starting from 0.

These functions are used to know where the current process sits in the distributed setup, here to decide which rank creates symlinks and performs other filesystem operations.

Creating a symbolic link named 'latest' that points to the newest checkpoint directory means other code can always refer to 'latest' instead of naming a specific step directory, which makes it easy to pick up the most recent model state during training.
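
The 'latest' link is ordinary pathlib symlink handling; a minimal standalone version (the paths here are made up for illustration):

    from pathlib import Path

    save_folder = Path("checkpoints")           # hypothetical save folder
    checkpoint_dir = save_folder / "step1000"   # the checkpoint that was just saved
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    latest = save_folder / "latest"
    latest.unlink(missing_ok=True)              # drop any existing link first
    # Relative link, as in _save_checkpoint: it points at the sibling directory name.
    latest.symlink_to(checkpoint_dir.name, target_is_directory=True)

    print(latest.resolve())                     # .../checkpoints/step1000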

5. Checkpoint operations

    def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
        checkpointer = build_sharded_checkpointer(self.cfg)
        return self._save_checkpoint(checkpointer, CheckpointType.sharded)

    def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
        checkpointer = build_sharded_checkpointer(self.cfg)
        return self._save_checkpoint(checkpointer, CheckpointType.sharded_ephemeral)

    def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
        oldest_checkpoint = checkpoints.pop(idx)
        barrier()
        if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
            shutil.rmtree(oldest_checkpoint, ignore_errors=True)
            latest_path = Path(self.cfg.save_folder) / "latest"
            if latest_path.resolve() == oldest_checkpoint.resolve():
                latest_path.unlink()
        barrier()

    def remove_sharded_checkpoint(self, idx: int = 0):
        self._remove_sharded_checkpoint(idx, self.checkpoints)

    def remove_ephemeral_checkpoint(self, idx: int = 0):
        self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)

    def restore_sharded_checkpoint(
        self,
        load_path: PathOrStr,
        local_cache: Optional[PathOrStr] = None,
        *,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
    ):
        # Zero-gradients to avoid gathering them.
        self.optim.zero_grad(set_to_none=True)
        checkpointer = build_sharded_checkpointer(self.cfg, name=sharded_checkpointer)
        trainer_state = checkpointer.restore_checkpoint(
            load_path,
            self.fsdp_model,
            self.optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )
        if load_trainer_state:
            self.load_trainer_state_dict(trainer_state)
        barrier()

    def save_unsharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
        checkpointer = FullCheckpointer(self.cfg)
        return self._save_checkpoint(checkpointer, CheckpointType.unsharded)

    def remove_unsharded_checkpoint(self, idx: int = 0):
        barrier()
        oldest_checkpoint = self.unsharded_checkpoints.pop(idx)
        if get_global_rank() == 0 and oldest_checkpoint.is_dir():
            shutil.rmtree(oldest_checkpoint, ignore_errors=True)
            latest_path = Path(self.cfg.save_folder) / "latest-unsharded"
            if latest_path.resolve() == oldest_checkpoint.resolve():
                latest_path.unlink()
        barrier()

    def restore_unsharded_checkpoint(
        self,
        load_path: PathOrStr,
        local_cache: Optional[PathOrStr] = None,
        *,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
    ):
        # Zero-gradients to avoid gathering them.
        self.optim.zero_grad(set_to_none=True)
        checkpointer = FullCheckpointer(self.cfg)
        trainer_state = checkpointer.restore_checkpoint(
            load_path,
            self.fsdp_model,
            self.optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )
        if load_trainer_state:
            self.load_trainer_state_dict(trainer_state)
        barrier()

    def save_checkpoint(
        self, checkpoint_type: CheckpointType = CheckpointType.sharded
    ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
        if checkpoint_type == CheckpointType.sharded:
            return self.save_sharded_checkpoint()
        elif checkpoint_type == CheckpointType.unsharded:
            return self.save_unsharded_checkpoint()
        elif checkpoint_type == CheckpointType.sharded_ephemeral:
            return self.save_ephemeral_checkpoint()
        else:
            raise NotImplementedError(checkpoint_type)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        checkpoint_type: Optional[CheckpointType] = None,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        sharded_checkpointer: Optional[ShardedCheckpointerType] = None,
    ):
        if checkpoint_type == CheckpointType.unsharded or (
            checkpoint_type is None and str(load_path).rstrip("/").endswith("-unsharded")
        ):
            self.restore_unsharded_checkpoint(
                load_path,
                local_cache=local_cache,
                load_optimizer_state=load_optimizer_state,
                load_trainer_state=load_trainer_state,
            )
        elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
            self.restore_sharded_checkpoint(
                load_path,
                local_cache=local_cache,
                load_optimizer_state=load_optimizer_state,
                load_trainer_state=load_trainer_state,
                sharded_checkpointer=sharded_checkpointer,
            )
        elif checkpoint_type is not None:
            raise NotImplementedError(checkpoint_type)

    def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
        if checkpoint_type == CheckpointType.sharded:
            self.remove_sharded_checkpoint(idx=idx)
        elif checkpoint_type == CheckpointType.unsharded:
            self.remove_unsharded_checkpoint(idx=idx)
        elif checkpoint_type == CheckpointType.sharded_ephemeral:
            self.remove_ephemeral_checkpoint(idx=idx)
        else:
            raise NotImplementedError(checkpoint_type)

save_sharded_checkpoint() saves a sharded checkpoint: it builds a sharded checkpointer and delegates to the internal _save_checkpoint() method, returning the path(s) of the saved checkpoint.

save_ephemeral_checkpoint() saves an ephemeral sharded checkpoint and works the same way as save_sharded_checkpoint(), just with CheckpointType.sharded_ephemeral.

_remove_sharded_checkpoint() deletes the sharded checkpoint at the given index. It pops the entry from the given checkpoints list and, on the filesystem-local rank 0 process (get_fs_local_rank() == 0), removes the corresponding directory from disk. If the 'latest' symlink resolves to the directory being deleted, the symlink is removed as well. That link is the single handle other code uses to refer to the newest checkpoint, so it must never be left dangling; deleting the checkpoint it points at therefore also means deleting the link. The surrounding barrier() calls synchronize the processes so that every rank sees a consistent view of the checkpoint directories before and after the deletion.

Deleting sharded checkpoints is typically triggered in these situations (see the retention sketch after this list):

  1. When saving a new checkpoint: older checkpoints are removed to keep storage usage bounded, since usually only a limited number of the most recent checkpoints need to be kept.

  2. Periodic cleanup during training: to keep checkpoints from filling up storage, older ones may be cleaned up on a schedule.

  3. When some condition is met: for example the filesystem is running low on space, or a user-defined retention policy applies.
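
A minimal sketch of the retention rule, using plain strings in place of checkpoint directories; it mirrors the "remove old checkpoints" loop in _save_checkpoint:

    # Keep only the newest `num_to_keep` entries, always dropping the oldest first.
    checkpoints = ["step100", "step200", "step300", "step400"]
    num_to_keep = 2

    while len(checkpoints) > num_to_keep:
        oldest = checkpoints.pop(0)             # index 0 is the oldest entry
        print(f"removing {oldest}")

    print(checkpoints)                          # ['step300', 'step400']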

Restoring weights

Restoring an unsharded checkpoint generally involves these steps:

  1. Initialize the model and optimizer: build a model instance with the same architecture and an optimizer with the same configuration as used during training.

  2. Load the checkpoint: read the model weights and optimizer state from the saved checkpoint and load them into the model and optimizer instances.

  3. Set the learning rate and weight decay: if a learning-rate scheduler was used during training, the learning rate may need to be re-set, and likewise the weight decay.

  4. Continue training: with the restored model and optimizer, training resumes as usual, iterating over the dataset and updating the parameters to minimize the loss.

In this code, restore_unsharded_checkpoint implements exactly that. It calls FullCheckpointer's restore_checkpoint method, which loads the model weights and optimizer state. If load_trainer_state=True, the returned trainer state is passed to load_trainer_state_dict to restore trainer-level state such as the global step. The final barrier() synchronizes all processes, so execution continues only once the checkpoint has been loaded everywhere. restore_sharded_checkpoint works the same way, except it uses the matching sharded checkpointer (build_sharded_checkpointer) so that state sharded across multiple devices is loaded and synchronized correctly.

remove_checkpoint follows the same pattern: it takes idx (the index of the checkpoint to delete) and checkpoint_type (sharded, unsharded, or sharded-ephemeral) and dispatches to the matching removal method. save_checkpoint likewise dispatches to the matching save method based on the type.

restore_checkpoint restores training state from a checkpoint. It takes load_path (the checkpoint path), an optional checkpoint_type, an optional local_cache path, the flags load_optimizer_state and load_trainer_state (whether to load the optimizer and trainer state), and an optional sharded_checkpointer (a specific sharded checkpointer implementation), then restores according to the given or inferred type.
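
A minimal sketch of restore_checkpoint's dispatch rule, pulled out as a standalone helper for clarity (the helper name is made up):

    # An explicit checkpoint_type wins; otherwise a path ending in "-unsharded" selects
    # the unsharded loader, and anything else falls back to the sharded loader.
    def infer_checkpoint_type(load_path, checkpoint_type=None):
        if checkpoint_type is not None:
            return checkpoint_type
        return "unsharded" if str(load_path).rstrip("/").endswith("-unsharded") else "sharded"

    print(infer_checkpoint_type("checkpoints/step1000-unsharded"))  # unsharded
    print(infer_checkpoint_type("checkpoints/step1000"))            # sharded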

6. Token handling

    def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
        # Labels are just input IDs shifted to the left (first item is ignored).
        labels, label_mask, attention_mask = (
            batch["input_ids"].clone(),
            batch.get("label_mask"),
            batch.get("attention_mask"),
        )
        if label_mask is not None:
            labels.masked_fill_(~label_mask, -100)
        if attention_mask is not None:
            labels.masked_fill_(attention_mask == 0.0, -100)
        return labels[..., 1:].contiguous()

    def model_forward(
        self, batch: Dict[str, Any], loss_reduction: str = "mean"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # shape: (batch_size, seq_len, vocab_size)
        logits = self.fsdp_model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
            attention_bias=batch.get("attention_bias"),
        ).logits
        logits_for_loss = logits[..., :-1, :].contiguous()
        # shape: (batch_size * seq_len, vocab_size)
        logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1))
        # shape: (batch_size, seq_len)
        labels = self.get_labels(batch)
        # shape: (batch_size * seq_len,)
        labels = labels.view(-1)
        ce_loss = F.cross_entropy(logits_for_loss, labels, ignore_index=-100, reduction=loss_reduction)
        if loss_reduction == "none":
            # Reshape (batch_size * seq_len,) -> (batch_size, seq_len)
            ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1)
        return ce_loss, logits

First, get_labels(). Its argument is a batch, which is a dict whose keys depend on how the batch was built. It clones input_ids, masks out positions that should not contribute to the loss (padding and other special positions become -100), and drops the first position, because the labels are the input IDs shifted left by one.

get_labels produces the targets for training. In language modeling the label at each position is the next token of the input sequence; this function builds those labels from the input, applies the masking, and returns the model's target output.

This function is used mainly inside model_forward, so let's see how it is used there.

model_forward computes the loss and produces the model's predictions.

    logits = self.fsdp_model(
        input_ids=batch["input_ids"],
        attention_mask=batch.get("attention_mask"),
        attention_bias=batch.get("attention_bias"),
    ).logits

Another concept worth clarifying:

logits give, for each position, a score over the vocabulary for the next token. Their shape is (batch_size, seq_len, vocab_size): every position gets a prediction score for every vocabulary entry. If seq_len is 5 and vocab_size is 10, logits has shape (batch_size, 5, 10).

logits[..., :-1, :] drops the last position of the sequence: the labels were shifted left by one, so the final position of logits has no corresponding label. .contiguous().view(-1, vocab_size) then flattens logits (and labels, analogously) so the cross-entropy loss can be computed.

For example, suppose the input is the single sentence "我爱学习" ("I love studying"): batch_size = 1, seq_len = 4, vocabulary size n, ignoring any extra special tokens.

The logits used for the loss then have shape (3, n) for this example, where n is the vocabulary size: row 1 is the distribution over the token following "我", row 2 over the token following "我爱", and row 3 over the token following "我爱学". There is no token after "我爱学习", so the last position of logits is dropped, and masked positions are ignored when computing the loss.

The loss is computed with PyTorch's F.cross_entropy, where ignore_index=-100 tells it to skip positions whose label is -100.

If loss_reduction is "none", the loss tensor is reshaped from (batch_size * (seq_len - 1),) back to (batch_size, seq_len - 1).

The choice of the reduction argument usually depends on the training task:

  1. 'mean' (average): use when the objective is the average loss per token or sample, the usual case of minimizing the mean loss.

  2. 'sum' (total): use when the objective is the total loss over the whole batch.

  3. 'none': use when each sample's loss needs to be handled or analyzed individually.

In short: pick 'mean' when you care about average performance, 'sum' when you care about the total, and 'none' when you need per-sample losses (as the evaluation code below does).
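
A toy end-to-end version of the shift-and-flatten pattern, using random tensors in place of a real model, so the shapes are easy to follow:

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 2, 5, 11
    logits = torch.randn(batch_size, seq_len, vocab_size)          # stand-in for model output
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Labels drop the first position (labels[..., 1:]); predictions drop the last
    # (logits[..., :-1, :]), so position t predicts token t + 1.
    labels = input_ids[..., 1:].contiguous()                       # (2, 4)
    shift_logits = logits[..., :-1, :].contiguous()                # (2, 4, 11)

    loss_mean = F.cross_entropy(
        shift_logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100
    )
    loss_none = F.cross_entropy(
        shift_logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100, reduction="none"
    ).view(batch_size, -1)

    print(loss_mean.shape, loss_none.shape)                        # torch.Size([]) torch.Size([2, 4])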

7. Training on one batch

    def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Split into micro-batches.
        micro_batches = self.split_batch(batch)

        # In case this helps with memory utilization.
        del batch

        ce_batch_loss = torch.tensor(0.0, device=self.device)
        z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
        for micro_batch in micro_batches:
            with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                # Run forward pass.
                ce_loss, logits = self.model_forward(micro_batch)
                ce_loss = ce_loss / len(micro_batches)

                # In case this helps with memory utilization.
                del micro_batch

                # Update overall CE batch loss.
                ce_batch_loss += ce_loss.detach()

                # Get loss to optimize for.
                if self.cfg.softmax_auxiliary_loss:
                    z_squared = logits.logsumexp(-1).pow(2).mean()
                    z_loss = 1e-4 * z_squared / len(micro_batches)
                    loss = ce_loss + z_loss

                    # Update overall Z batch loss.
                    z_batch_loss += z_loss.detach()
                else:
                    loss = ce_loss

                del logits

            # Run backward pass.
            loss.backward()

        return ce_batch_loss, z_batch_loss

train_batch trains on a single batch.

Split into micro-batches: the input batch is split into several micro-batches so memory is used more effectively; gradients are accumulated across them before the optimizer step.

Automatic mixed precision (AMP): torch.autocast enables mixed precision, which speeds up training and reduces memory use; dtype=self.cfg.autocast_precision selects the low-precision dtype.

Forward pass: each micro-batch is run through model_forward to get the cross-entropy loss (ce_loss) and the logits.

Accumulate the loss: the per-micro-batch cross-entropy losses (each divided by the number of micro-batches) are accumulated into ce_batch_loss for logging and later parameter updates.

Auxiliary loss: if softmax_auxiliary_loss is enabled in the config, an auxiliary z-loss is computed from the logits and added to the loss (see the sketch after this list). The computation is:

  1. Take the logsumexp of the logits over the vocabulary dimension; this is the log of the softmax normalizer Z at each position.
  2. Square that value at each position.
  3. Average the squared values over all positions to get z_squared.
  4. Multiply by 1e-4 and divide by the number of micro-batches to get this micro-batch's share of z_loss.

z_batch_loss is there to help the model learn better-behaved outputs during training. This kind of auxiliary loss acts as an extra regularizer on a global property of the model's outputs: it penalizes large log-normalizers, keeping the logits from drifting, which helps stability and generalization.

Finally, the backward pass is run for each micro-batch, accumulating gradients across micro-batches.
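
The z-loss term in isolation, computed with the same formula as in train_batch (toy tensors, made-up sizes):

    import torch

    num_micro_batches = 4
    logits = torch.randn(2, 8, 50)     # toy (batch, seq_len, vocab) logits

    # Square the log-partition function (logsumexp over the vocabulary), average over
    # all positions, scale by 1e-4, and split across micro-batches.
    z_squared = logits.logsumexp(dim=-1).pow(2).mean()
    z_loss = 1e-4 * z_squared / num_micro_batches
    print(z_loss.item())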

8. Training one step

    def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
        metrics: Dict[str, float] = {}

        # Write data-indices to file.
        if self.indices_file is not None and "index" in batch:
            indices = "\t".join(str(int(i)) for i in batch["index"])
            self.indices_file.write(f"{self.global_step}\t{indices}\n")

        # Zero-gradients.
        self.optim.zero_grad(set_to_none=True)

        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)

        # Run forward-backward pass.
        ce_batch_loss, z_batch_loss = self.train_batch(batch)

        # Collect loss, potentially reducing over all ranks.
        if reduce_global_loss:
            dist.reduce(ce_batch_loss, 0)
            ce_batch_loss.div_(get_world_size())
            if z_batch_loss is not None:
                dist.reduce(z_batch_loss, 0)
                z_batch_loss.div_(get_world_size())

        # Clip gradient norms and collect param/gradient/optim metrics.
        should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
        optim_metrics = self.optim.clip_grads_and_collect_metrics(
            self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
        )

        # Adjust the learning rate.
        for group in self.optim.param_groups:
            # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
            # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
            # the corresponding values from `self.cfg`.
            group["lr"] = self.scheduler.get_lr(
                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
            )
            group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
                self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
            )
            group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
            )

        # Optimizer step.
        self.optim.step()

        # Collect metrics and check for NaN loss.
        # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
        if torch.isnan(ce_batch_loss):
            raise ValueError("nan loss encountered")
        if z_batch_loss is not None and torch.isnan(z_batch_loss):
            raise ValueError("nan loss encountered")
        for key, value in optim_metrics.items():
            metrics[f"optim/{key}"] = value.item()
        self.cur_train_loss = ce_batch_loss.item()
        self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
        metrics["train/CrossEntropyLoss"] = self.cur_train_loss
        metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
        if z_batch_loss is not None:
            metrics["train/ZLoss"] = z_batch_loss.item()

        # Maybe collect post-step optimizer-specific metrics.
        if should_log_optim_metrics_this_step:
            optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
            for key, value in optim_metrics.items():
                metrics[f"optim/{key}"] = value.item()

        return metrics

If indices_file is set and the batch contains an "index" key, the batch's indices are written to the file, recording the order in which data was seen during training.

The optimizer's gradients are zeroed; set_to_none=True frees the gradient tensors instead of filling them with zeros, which saves memory. The batch tensors are moved to the target device (usually the GPU). train_batch then runs the forward and backward passes and returns the cross-entropy loss ce_batch_loss and the auxiliary loss z_batch_loss. If reduce_global_loss is True, the losses are reduced across ranks and averaged, so that in distributed training the loss from all devices is aggregated on the main rank.

Gradients are clipped and parameter/gradient/optimizer metrics are collected; should_log_optim_metrics_this_step controls whether the expensive per-parameter metrics are gathered on this step. Gradient clipping is a technique against exploding gradients: it ensures the gradient norm (or, in some variants, each element) does not exceed a predefined threshold, which helps stabilize training, especially in deep networks. Exploding gradients happen when gradients become very large during backpropagation and cause numerical instability; clipping avoids this by computing the gradient norm and rescaling it down to a maximum value, as in the standard PyTorch utility sketched below.
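
OLMo's optimizer performs the clipping inside clip_grads_and_collect_metrics; the generic norm-based form, using PyTorch's built-in utility on a throwaway model, looks like this:

    import torch

    model = torch.nn.Linear(16, 4)
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()

    # Rescale gradients so their total norm is at most max_norm; returns the pre-clip norm.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"gradient norm before clipping: {total_norm:.3f}")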

The learning rate and clipping thresholds are then adjusted from the scheduler. The code iterates over all parameter groups in the optimizer and uses the scheduler to compute a new learning rate from the base learning rate, the current scheduler position (self.scheduler_current), and the maximum position (self.scheduler_max); this is where a schedule such as learning-rate decay or a cyclic adjustment takes effect. In the same way, the scheduler produces new values for max_grad_norm and max_grad_norm_ratio, so the gradient-clipping threshold can change across different phases of training; group["max_grad_norm_ratio"] is an alternative way of expressing the clipping threshold, scaled by a ratio parameter. self.optim.step() then performs one optimizer update using the adjusted learning rate and clipping settings.

Parameter groups in the optimizer: how parameters are grouped usually depends on the model structure and training needs. Common strategies include (see the sketch after this list):

  1. A single global learning rate: all parameters share one learning rate. This is the simplest option, suitable when all parameters need similar learning rates.

  2. Per-layer learning rates: parameters from different layers go into different groups, each with its own learning rate; for example, a smaller rate for low-level parameters such as embeddings and a larger one for top-level parameters such as the output projection.

  3. Per-module learning rates: when the model consists of several sub-modules, each can get its own learning rate, giving finer control over how fast different parts train.

  4. Adaptive learning rates: optimizers such as Adam already adapt a per-parameter effective step size based on each parameter's behaviour during training.

  5. Task-specific learning rates: a few task-critical parameters can be given dedicated learning rates so they are optimized more aggressively.
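
A small illustration of per-group hyperparameters with a stock PyTorch optimizer; the embedding/head split here is illustrative only, not OLMo's actual grouping:

    import torch

    embed = torch.nn.Embedding(1000, 64)
    head = torch.nn.Linear(64, 1000)

    # Two parameter groups with different learning rates and weight decay.
    optim = torch.optim.AdamW([
        {"params": embed.parameters(), "lr": 1e-4},
        {"params": head.parameters(), "lr": 3e-4, "weight_decay": 0.0},
    ])

    for group in optim.param_groups:
        print(group["lr"], group["weight_decay"])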

The losses are checked for NaN; if a NaN loss is found, a ValueError is raised. A NaN loss usually means something went wrong during training, such as exploding gradients.

Several optimizer metrics are then written into metrics (these may include learning rates, gradient statistics, and so on). The current training loss (the cross-entropy loss) is recorded; min_train_loss tracks the smallest loss seen during training. Perplexity, a standard language-modeling metric, is the exponential of the cross-entropy loss. If an auxiliary loss exists (z_batch_loss is not None), it is recorded as well. After the optimizer step, additional optimizer-specific metrics may also be collected, depending on the optimizer implementation.

self.optim.get_post_step_metrics(self.fsdp_model) fetches extra metrics from the optimizer after the step. What is returned depends on the specific optimizer, since different optimizers expose different information; typically it can include things like learning-rate changes or gradient information.

Such post-step optimizer metrics are often useful for monitoring and debugging: checking how gradients evolve, verifying that learning-rate adjustments behave as expected, and keeping an eye on other aspects of the optimization process.

9. Evaluation

    def eval_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
            ce_loss, logits = self.model_forward(batch, loss_reduction="none")
        return ce_loss.mean(dim=-1), logits

    def eval_step(self, batch: Dict[str, Any], evaluator: Evaluator) -> None:
        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)

        # Run forward pass.
        with torch.no_grad():  # NOTE: 'torch.inference_mode()' doesn't work with 'torch.compile()'.
            ce_loss, logits = self.eval_batch(batch)

        # Update metrics.
        evaluator.update_metrics(
            batch, ce_loss, logits
        )  # batch includes all keys that the downstream evaluation needs

        barrier()

eval_batch runs the forward pass for one evaluation batch. Inside the torch.autocast context it calls model_forward with loss_reduction="none", so the cross-entropy loss is kept per token rather than reduced. It returns the loss averaged over the sequence dimension (one value per sample) together with the logits.

eval_step handles a full evaluation step. The batch is first moved to the right device; eval_batch then runs the forward pass under torch.no_grad(), so no gradients are computed, which reduces memory use and speeds things up (the comment notes that torch.inference_mode() does not work with torch.compile()).

Finally, evaluator.update_metrics updates the evaluation metrics. The Evaluator's job is to update a set of metrics (accuracy, downstream task scores, and so on) from the model's outputs and the batch; update_metrics takes the batch, the per-sample loss, and the logits, since the batch carries every key the downstream evaluation needs. The trailing barrier() synchronizes the processes.

10. Splitting a batch into micro-batches

    def split_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
        microbatch_size = self.cfg.device_train_microbatch_size
        batch_size = batch["input_ids"].shape[0]
        if batch_size <= microbatch_size:
            return [batch]
        else:
            micro_batches = {}
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    micro_batches[key] = value.split(microbatch_size, dim=0)
                elif isinstance(value, list):
                    micro_batches[key] = [
                        value[microbatch_size * i : microbatch_size * i + microbatch_size]
                        for i in range(math.ceil(batch_size / microbatch_size))
                    ]
                else:
                    raise ValueError(f"unexpected item in batch: '{key}={value}'")
            return [
                {key: value[i] for key, value in micro_batches.items()}  # type: ignore
                for i in range(len(micro_batches["input_ids"]))
            ]

  • split_batch takes a batch, reads the micro-batch size from the config, and reads the batch size from input_ids.
  • If the batch size is at most the micro-batch size, no splitting is needed and the batch is returned as a single-element list.
  • If the batch is larger than the micro-batch size, the splitting logic runs:
    • A dict micro_batches holds the split pieces of every entry in the batch.
    • torch.Tensor entries are split with Tensor.split along dimension 0, the batch dimension.
    • list entries are split with list slicing.
    • Any other type raises a ValueError.
  • Finally, a list comprehension assembles one dict per micro-batch, each containing that micro-batch's slice of every entry.

This splitting lets a large per-device batch be processed as several smaller forward/backward passes with gradient accumulation, which keeps memory usage within bounds while preserving the effective batch size.

The effect of the split looks like this:

    big_batch = {
        "input_ids": torch.randn(5, 10),
        "other_feature": torch.randn(5, 20),
        # possibly other keys
    }

With a micro-batch size of 2, this becomes:

Micro-batch 1: input_ids shape torch.Size([2, 10]), other_feature shape torch.Size([2, 20])

Micro-batch 2: input_ids shape torch.Size([2, 10]), other_feature shape torch.Size([2, 20])

Micro-batch 3: input_ids shape torch.Size([1, 10]), other_feature shape torch.Size([1, 20])
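
The same splitting rule as a runnable snippet, restricted to tensor-valued entries:

    import math
    import torch

    microbatch_size = 2
    big_batch = {
        "input_ids": torch.randn(5, 10),
        "other_feature": torch.randn(5, 20),
    }

    # Split every entry along the batch dimension, then regroup per micro-batch.
    pieces = {key: value.split(microbatch_size, dim=0) for key, value in big_batch.items()}
    micro_batches = [
        {key: value[i] for key, value in pieces.items()}
        for i in range(math.ceil(5 / microbatch_size))
    ]

    for i, mb in enumerate(micro_batches, start=1):
        print(i, mb["input_ids"].shape, mb["other_feature"].shape)
    # 1 torch.Size([2, 10]) torch.Size([2, 20])
    # 2 torch.Size([2, 10]) torch.Size([2, 20])
    # 3 torch.Size([1, 10]) torch.Size([1, 20])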

11. Logging system state

    def system_metrics(self) -> Dict[str, float]:
        metrics = {}
        if self.global_step < 3 or self.global_step % 10 == 0:
            peak_gpu_mb = peak_gpu_memory()
            if peak_gpu_mb is not None:
                metrics["System/Peak GPU Memory (MB)"] = peak_gpu_mb
        return metrics

    def log_metrics_to_console(self, prefix: str, metrics: Dict[str, float]):
        def format_float(value: float) -> str:
            if value < 0.0001:
                return str(value)  # scientific notation
            elif value > 1000:
                return f"{int(value):,d}"
            elif value > 100:
                return f"{value:.1f}"
            elif value > 10:
                return f"{value:.2f}"
            elif value > 1:
                return f"{value:.3f}"
            else:
                return f"{value:.4f}"

        log.info(
            f"{prefix}\n"
            + "\n".join(
                [
                    f"    {name}={format_float(value)}"
                    for name, value in metrics.items()
                    if not name.startswith("optim/")  # there's too many optimizer metrics
                ]
            )
        )
    def should_log_optim_metrics_this_step(self) -> bool:
        if self.cfg.wandb is None:
            # We only log optimizer-specific metrics to W&B, since there are usually too many metrics
            # to log to the console.
            return False
        optim_log_interval = self.cfg.optimizer.metrics_log_interval
        if optim_log_interval is None:
            optim_log_interval = self.cfg.wandb.log_interval
        else:
            optim_log_interval = max(optim_log_interval, self.cfg.wandb.log_interval)
        return self.global_step % optim_log_interval == 0

    def should_log_this_step(self) -> bool:
        if self.global_step % self.cfg.console_log_interval == 0:
            return True
        elif self.cfg.wandb is not None and self.global_step % self.cfg.wandb.log_interval == 0:
            return True
        else:
            return False

system_metrics and log_metrics_to_console:

  1. system_metrics collects a few system-level metrics, mainly peak GPU memory usage. If the global step is less than 3 or divisible by 10, the current peak GPU memory is recorded. It returns a dict of these system metrics.

  2. log_metrics_to_console writes metrics to the console. It takes a prefix and a dict of metrics, iterates over the dict, and prints each metric's name and value in a consistent format.

  • format_float formats floats according to their magnitude: scientific notation for very small values, a thousands separator for large values, and a varying number of decimals in between.
  • log.info prints the prefixed line with the formatted metrics; metrics whose names start with "optim/" are skipped because there are too many optimizer metrics for the console.

Together, these two methods collect system metrics and print them to the console.

  3. should_log_optim_metrics_this_step decides whether optimizer-specific metrics should be logged at this step. If W&B (Weights & Biases) is not configured, it returns False, since optimizer metrics are only logged to W&B.
  • If the optimizer's metrics_log_interval is None, the W&B log interval is used instead.
  • Otherwise the larger of the two intervals is used, so optimizer metrics are never logged more often than W&B itself logs.
  • The return value is whether the global step is a multiple of that effective interval (see the sketch below).

  4. should_log_this_step decides whether anything should be logged at the current step, based on the console log interval and the W&B log interval.
  • If the global step is divisible by the configured console log interval, or W&B is configured and the step is divisible by the W&B log interval, it returns True; otherwise it returns False.
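
The interval rule from should_log_optim_metrics_this_step, pulled out as a standalone helper (the helper name is made up for illustration):

    # Fall back to the W&B interval when no optimizer interval is configured, and never
    # log optimizer metrics more often than W&B itself logs.
    def effective_optim_log_interval(optim_interval, wandb_interval):
        if optim_interval is None:
            return wandb_interval
        return max(optim_interval, wandb_interval)

    print(effective_optim_log_interval(None, 100))   # 100
    print(effective_optim_log_interval(50, 100))     # 100
    print(effective_optim_log_interval(500, 100))    # 500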

12. The evaluation loop

    def eval(self) -> Dict[str, Any]:
        # Zero gradients and set model to 'eval' mode.
        self.optim.zero_grad(set_to_none=True)
        self.fsdp_model.eval()

        eval_metrics = {}
        for evaluator in self.evaluators:
            log.info(f"Running evaluation for '{evaluator.label}'...")

            # Reset metrics.
            evaluator.reset_metrics()

            # Initialize data loader iterator.
            eval_batches = iter(evaluator.eval_loader)

            # Adjust how many batches to evaluate on.
            num_eval_batches = (
                evaluator.subset_num_batches
                if evaluator.subset_num_batches is not None
                else self.cfg.eval_subset_num_batches
            )
            if num_eval_batches > 0:
                num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
                eval_batches = islice(eval_batches, num_eval_batches)

            # Run model over batches.
            for eval_step, eval_batch in enumerate(eval_batches):
                self.eval_step(eval_batch, evaluator)

                # Log to console.
                if eval_step + 1 == num_eval_batches or (eval_step + 1) % self.cfg.console_log_interval == 0:
                    log.info(f"[eval_step={eval_step + 1}/{num_eval_batches}]")

            # Get final metrics.
            metrics = evaluator.compute_metrics()
            eval_metrics.update(metrics)
            self.log_metrics_to_console(f"{evaluator.label}", metrics)

            del eval_batches

        return eval_metrics

  1. Setup:
  • Zero the gradients and switch the model to 'eval' mode, which changes the behaviour of layers such as BatchNorm and Dropout.
  • The evaluators and their data loaders provide the evaluation data.

  2. Loop over evaluators: for each evaluator in self.evaluators, log its label, call reset_metrics to clear its metrics, and create an iterator over its eval data loader. Then work out how many batches to evaluate: the evaluator's own subset_num_batches if set, otherwise cfg.eval_subset_num_batches; if that number is positive, the iterator is capped with islice.

  3. Loop over the evaluation batches:

  • Run eval_step on each batch, which updates the evaluator's metrics.
  • Log a progress line when the last batch is reached or every console_log_interval batches.
  • After the loop, compute_metrics produces the final metrics,
  • the overall eval_metrics dict is updated with them,
  • and the metrics are printed to the console.

  4. Return the combined evaluation metrics: the metrics of all evaluators merged into a single dict.

This method runs evaluation over the model and records the metrics along the way. For each evaluator, the model is run over batches of its dataset and the corresponding metrics (loss, accuracy, and so on) are collected; the combined metrics are returned so the user can see how the model performs on the evaluation data.
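
Capping the number of evaluation batches is just itertools.islice over the loader iterator; a tiny standalone demonstration with a list standing in for a DataLoader:

    from itertools import islice

    eval_loader = [f"batch{i}" for i in range(10)]   # stand-in for a DataLoader
    num_eval_batches = 3

    # Limit the iterator without materializing the whole loader.
    eval_batches = islice(iter(eval_loader), num_eval_batches)
    print(list(eval_batches))                        # ['batch0', 'batch1', 'batch2']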

13. Cancelling training

    def check_if_cancelled(self) -> Tuple[bool, int]:
        should_cancel = False
        cancel_reason: Optional[str] = None
        extra_steps = 0
        if get_global_rank() == 0:
            if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
                # First check if we've reached the training time limit.
                should_cancel = True
                cancel_reason = "time limit reached"
                extra_steps = self.cfg.extra_steps_after_cancel
            elif (
                self.cfg.early_stopping_factor is not None
                and self.global_step > self.cfg.scheduler.t_warmup
                and self.cur_train_loss > self.cfg.early_stopping_factor * self.min_train_loss
            ):
                # Next check if early stopping loss criteria is met.
                should_cancel = True
                cancel_reason = "early stopping from loss increase"
            elif wandb.run is not None and (api_key := os.environ.get("WANDB_API_KEY")) is not None:
                # Finally, check if someone canceled the run from W&B by adding the 'cancel' / 'canceled' tag..
                # We won't see it in the run object. So we have to use the import/export API to check.
                from requests.exceptions import RequestException

                try:
                    api = wandb.Api(api_key=api_key)
                    run = api.run(wandb.run.path)
                    for tag in run.tags or []:
                        if tag.lower() in {"cancel", "canceled", "cancelled"}:
                            should_cancel = True
                            cancel_reason = "Weights & Biases tag"
                            extra_steps = self.cfg.extra_steps_after_cancel
                            break
                except RequestException:
                    pass

        run_canceled = synchronize_flag(should_cancel, self.device)
        if run_canceled and cancel_reason is not None:
            extra_steps = synchronize_value(extra_steps, self.device)
            if extra_steps > 0:
                log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
            else:
                log.warning(f"Run canceled due to {cancel_reason}")

        return run_canceled, extra_steps

A few variables track whether training should be cancelled, the reason for cancellation, and how many extra steps to run.

Only the process with global rank 0 performs the checks below, so the cancellation decision is made exactly once.

If a training time limit is configured and the elapsed time since the start of training has reached that limit, should_cancel is set to True with the reason "time limit reached", and the extra step count is taken from the config.

Otherwise, if an early-stopping factor is configured, the global step is past the scheduler warmup, and the current training loss exceeds the early-stopping factor times the minimum training loss, should_cancel is set to True with the reason "early stopping from loss increase".

Finally, it checks whether someone cancelled the run from Weights & Biases by adding a 'cancel' / 'canceled' / 'cancelled' tag. The tag is not visible on the local run object, so the W&B API is queried; if the tag is found, should_cancel is set to True with the reason "Weights & Biases tag", and the extra step count is set.

synchronize_flag broadcasts should_cancel to all devices.

If the run was cancelled and a reason is set, the extra step count is synchronized to all devices as well.

If the extra step count is positive, a warning says the run was cancelled and will stop after that many more steps; otherwise only the cancellation reason is logged.

The method returns whether the run was cancelled and the number of extra steps.

The extra steps (extra_steps_after_cancel) give the run a graceful way to wind down once a cancellation condition is detected. Sometimes it is better to keep training a little longer rather than stop immediately, because at the moment of cancellation the model may not yet be in a meaningful state. The extra steps give the model more opportunity to finish its current optimization work and settle into a better state before the final checkpoint is written, avoiding the instability or incomplete state that an abrupt interruption could leave behind, for example cutting off convergence work in the last phase of training.
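
A toy version of the early-stopping condition only, with made-up numbers, to make the comparison concrete:

    # Cancel once the current loss has drifted above early_stopping_factor * best-loss-so-far,
    # but only after the scheduler warmup has finished.
    def should_early_stop(cur_loss, min_loss, step, t_warmup, factor):
        return factor is not None and step > t_warmup and cur_loss > factor * min_loss

    print(should_early_stop(2.4, 2.0, 5000, 2000, 1.1))   # True: 2.4 > 1.1 * 2.0
    print(should_early_stop(2.1, 2.0, 5000, 2000, 1.1))   # False
    print(should_early_stop(2.4, 2.0, 1500, 2000, 1.1))   # False: still in warmup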

14. The main training loop, part 1: preparation

    def fit(self):
        self._start_time = time.time()

        if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
            eval_metrics = self.eval()
            if wandb.run is not None:
                wandb.log(eval_metrics, step=self.global_step)

        # Set model to 'train' mode.
        self.fsdp_model.train()

        # Initialize monitors.
        assert self.cfg.device_train_batch_size is not None
        speed_monitor = SpeedMonitor(self.cfg.speed_monitor)
        lr_monitor = LRMonitor(self.optim)

        # Log system metrics at the start of training.
        sys_metrics = self.system_metrics()
        if sys_metrics:
            self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
            if wandb.run is not None:
                wandb.log(sys_metrics, step=0)

        # Python Profiler stuff
        if self.cfg.python_profiling:
            python_profiler = cProfile.Profile()
        else:
            python_profiler = None

        # PyTorch Profiler stuff
        if self.cfg.torch_profiling and get_global_rank() == 0:
            from torch.profiler import schedule

            profiling_schedule = schedule(wait=1, warmup=5, active=3)

            def on_trace_ready(p):
                profiler_output_dir = Path(self.cfg.save_folder) / "profiler"
                profiler_output_dir.mkdir(exist_ok=True)

                output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32)
                log.info(f"Profile by total GPU time at step {p.step_num}:\n{output}")
                output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
                log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")

                p.export_chrome_trace(
                    str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
                )
                if self.cfg.remote_save_folder is not None:
                    upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
                    log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
                    upload(trace_path, f"{upload_folder}/{trace_path.name}")

            from torch.profiler import ProfilerActivity

            torch_profiler = torch.profiler.profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=False,
                profile_memory=False,
                with_stack=True,
                schedule=profiling_schedule,
                on_trace_ready=on_trace_ready,
            )
            del profiling_schedule
        else:
            import contextlib

            torch_profiler = contextlib.nullcontext()
  1. Record the start time.

  2. If a load path is set, the global step is already positive, and the config asks for evaluation on load, run self.eval() once and log the evaluation metrics to W&B. Then put the model back into training mode.

  3. Assert that the per-device training batch size is defined in the config.

  4. Initialize the speed monitor and the learning-rate monitor. Collect system metrics such as peak GPU memory; if there are any, print them to the console and, if W&B is in use, log them with step=0 to mark the start of training.

  5. Optionally enable the Python profiler. A profiler analyzes where a program spends its time, helping to identify bottlenecks: which parts consume the most time and where optimization might be needed. The standard library's cProfile module provides detailed runtime statistics, such as how long each function call takes and how many times it is called (see the standalone sketch after this list).

  6. If the PyTorch profiler is enabled (PyTorch's built-in tool for analyzing and optimizing model performance), set up its schedule, define the on_trace_ready(p) callback that runs when a trace is ready, and create the torch_profiler object with the relevant options. If it is not enabled, a contextlib.nullcontext() is used instead, so the with-block below still works.

  7. on_trace_ready: profiler_output_dir = Path(self.cfg.save_folder) / "profiler" defines where traces are saved, a "profiler" subdirectory of the save folder, and mkdir(exist_ok=True) makes sure it exists. p.key_averages().table(sort_by="self_cuda_time_total", row_limit=32) builds a table of the profiler's key metrics sorted by total CUDA time (at most 32 rows), which is written to the log; the same is done sorted by CPU time. p.export_chrome_trace(...) exports the trace in Chrome Trace format to a compressed JSON file named after the step. If a remote save folder is configured, the upload directory is derived from it, a log line notes that tracing is complete and the upload is starting, and the trace file is uploaded there. This callback generates, saves, and optionally uploads the trace whenever one is ready, which is very useful for performance analysis and tuning.
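
Standalone cProfile usage in the same spirit as fit(): enable around a region of interest, then disable and print cumulative-time statistics (the profiled work here is a throwaway computation):

    import cProfile
    import pstats
    from pstats import SortKey

    profiler = cProfile.Profile()
    profiler.enable()
    total = sum(i * i for i in range(100_000))   # stand-in for a few training steps
    profiler.disable()

    # Print the top 5 entries sorted by cumulative time.
    pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE).print_stats(5)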

14. The main training loop, part 2: training

# Train.
        first_batch: bool = True
        cancel_initiated: bool = False
        stop_at: Optional[int] = self.cfg.stop_at
        save_checkpoints: bool = True

        with torch_profiler as p:
            for epoch in range(self.epoch or 0, self.max_epochs):
                for batch in self.train_loader:
                    # Bookkeeping.
                    # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
                    # batches see the same number of tokens, which should be the case for language model pre-training
                    # (at least when drop_last=True).
                    # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
                    # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
                    # fail loudly.
                    batch_size, seq_len = batch["input_ids"].shape
                    assert seq_len == self.cfg.model.max_sequence_length
                    assert batch_size == self.cfg.device_train_batch_size
                    global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
                    self.global_step += 1
                    self.global_train_examples_seen_this_epoch += global_batch_size
                    self.global_train_tokens_seen += global_batch_size * seq_len
                    speed_monitor.batch_start(
                        self.global_train_tokens_seen,
                        batch_size * seq_len,  # num tokens in batch for this device
                        # We start monitoring speed after the first batch since the first
                        # batch might be an outlier due to compiling and other initialization overhead.
                        record=not first_batch,
                    )

                    should_log_this_step = self.should_log_this_step()

                    # Run train step on batch.
                    metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

                    # Maybe collect other metrics.
                    if should_log_this_step:
                        # Speed metrics.
                        metrics.update(speed_monitor.check())
                        # System metrics.
                        metrics.update(self.system_metrics())
                        # Learning rate metrics.
                        metrics.update(lr_monitor.check())

                    # Log metrics to console.
                    if self.global_step % self.cfg.console_log_interval == 0:
                        self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

                    # Log metrics to W&B.
                    if (
                        wandb.run is not None
                        and self.cfg.wandb is not None
                        and self.global_step % self.cfg.wandb.log_interval == 0
                    ):
                        wandb.log(metrics, step=self.global_step)

                    # Check if/when run should be canceled.
                    if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                        cancel_initiated, extra_steps = self.check_if_cancelled()
                        if cancel_initiated:
                            stop_at = (
                                self.global_step + extra_steps
                                if stop_at is None
                                else min(self.global_step + extra_steps, stop_at)
                            )

                    # Maybe save sharded checkpoint.
                    if save_checkpoints and (
                        cancel_initiated
                        or (
                            self.global_step % self.cfg.save_interval == 0
                            and self.cfg.save_num_checkpoints_to_keep != 0
                        )
                    ):
                        log.info("Saving checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Remove any ephemeral checkpoints.
                        while self.ephemeral_checkpoints:
                            self.remove_ephemeral_checkpoint()

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                        # If the run was just canceled this will be the final checkpoint.
                        if cancel_initiated:
                            save_checkpoints = False
                    elif (
                        self.cfg.save_interval_ephemeral is not None
                        and self.global_step % self.cfg.save_interval_ephemeral == 0
                    ):
                        log.info("Saving ephemeral checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe save unsharded checkpoint.
                    if (
                        save_checkpoints
                        and self.cfg.save_interval_unsharded is not None
                        and self.global_step % self.cfg.save_interval_unsharded == 0
                        and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
                    ):
                        log.info("Saving unsharded checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                        log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe run evaluations.
                    if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                        eval_metrics = self.eval()

                        # Log metrics to W&B.
                        if wandb.run is not None:
                            wandb.log(eval_metrics, step=self.global_step)

                        # Reset speed monitor so that we don't count the time taken to run evaluations.
                        speed_monitor.reset()

                        # Reset model to 'train' mode.
                        self.fsdp_model.train()

                    # End of batch.
                    first_batch = False
                    if p is not None:
                        p.step()

                    if stop_at is not None and self.global_step >= stop_at:
                        break

                    # Python Profiler stuff
                    # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                    if python_profiler is not None:
                        if self.global_step == 5:
                            python_profiler.enable()
                        elif self.global_step == 8:
                            python_profiler.disable()
                            python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                            python_profiler = None
                else:
                    log.info("Training epoch complete")
                    self.epoch = epoch + 1
                    self.global_train_examples_seen_this_epoch = 0
                    if self.epoch < self.max_epochs:
                        self.dataset.reshuffle()
                    continue

                break

        # Save final checkpoint.
        if save_checkpoints:
            if self.cfg.save_interval_unsharded is not None:
                log.info("Saving final unsharded model checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
            elif self.cfg.save_num_checkpoints_to_keep != 0:
                log.info("Saving final checkpoint...")
                checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                log.info(f"Checkpoint saved to {checkpoint_path}")

  1. first_batch: bool = True: marks whether we are on the first batch of the training loop. It starts as True so the first batch can be treated differently, and is set to False afterwards.

  2. cancel_initiated: bool = False: marks whether cancellation has been initiated. At some point in the loop a cancellation condition may be detected (for example the time limit is reached), and cancel_initiated is then set to True.

  3. stop_at: Optional[int] = self.cfg.stop_at: a specific global step at which training should stop. It can be provided in the config; if not, it defaults to None, meaning no stop step is set.

  4. save_checkpoints: bool = True: marks whether checkpoints should still be saved. It can be set to False at certain points to avoid saving extra checkpoints, for example once the final checkpoint has been written after a cancellation.

  5. The torch_profiler context manager wraps the PyTorch profiler (or a null context), so profiling is active inside this block when enabled.

  6. Iterate over the training epochs.

  7. Within each epoch, iterate over the training data loader (self.train_loader).

  8. Bookkeeping on each batch: assertions check that the batch has the expected shape, i.e. the sequence length equals cfg.model.max_sequence_length and the batch size equals cfg.device_train_batch_size. As the comment explains, assuming all batches see the same number of tokens avoids a distributed all-reduce over sequence lengths, and the assertions fail loudly if that assumption is violated.

  9. The per-device batch size times the number of participating devices gives the global batch size (get_world_size() returns the total number of ranks, and batch sizes are assumed equal across ranks). The global step, the examples seen this epoch, and the total tokens seen are then updated to track training progress.

  10. speed_monitor.batch_start records throughput information: the total tokens seen globally and the number of tokens in this device's batch. Recording is skipped for the first batch, which can be an outlier due to compilation and other initialization overhead.

  11. Check whether this step should be logged.

  12. Call self.train_step on the batch. The reduce_global_loss argument says whether the loss should be reduced across ranks (only needed when it will be logged). The returned metrics hold the step's measurements.

  13. If should_log_this_step is True, i.e. the logging condition is met:

  • update the metrics with the speed monitor, the system metrics, and the learning-rate metrics (throughput, GPU memory usage, current learning rates, and so on);
  • write them to the console via self.log_metrics_to_console;
  • and, if Weights & Biases (W&B) is configured and active, log them to W&B as well.

  14. Every cfg.canceled_check_interval steps (as long as cancellation has not already been initiated), check whether training should be cancelled by calling self.check_if_cancelled. If it should, set cancel_initiated to True and take extra_steps, the number of additional steps to run before stopping. stop_at is then updated to the current step plus the extra steps, or to the minimum of that and any previously planned stop step.

  15. Next, check whether a sharded checkpoint should be saved. There are two triggers:

  • cancel_initiated is True, i.e. training is being cancelled;
  • or the current global step is a multiple of cfg.save_interval and cfg.save_num_checkpoints_to_keep is non-zero.

  16. When a checkpoint should be saved:

  • log that a checkpoint is being saved;
  • call self.save_checkpoint, which returns the checkpoint path among other things;
  • log the saved path;
  • remove any ephemeral checkpoints via self.remove_ephemeral_checkpoint;
  • reset the speed monitor so the time spent checkpointing is not counted as training throughput;
  • if training has been cancelled, set save_checkpoints to False, since this will be the final checkpoint.

  17. If the regular save condition is not met but cfg.save_interval_ephemeral is set and the current step is a multiple of it, save an ephemeral checkpoint instead by calling self.save_checkpoint with CheckpointType.sharded_ephemeral, log the path, and reset the speed monitor.

  18. Unsharded checkpoints are saved the same way on their own interval, and evaluations run every cfg.eval_interval steps (with the speed monitor reset and the model switched back to train mode afterwards), as described earlier.

  19. End of batch:

  • first_batch is set to False, since we are past the first batch;
  • if the PyTorch profiler is active (p is not None), p.step() advances it; the profiler collects performance data, and stepping it here makes sure the work of fetching the next batch is captured.

  20. Check the stop condition:

  • if stop_at is set, meaning training must stop at a particular global step, and the current global step has reached or passed it, break out of the loop and end training.

  21. Python profiler:

  • the Python profiler work is done at the bottom of the loop;
  • when python_profiler is not None and the global step reaches 5, the profiler is enabled;
  • at step 8 it is disabled, the cumulative-time statistics are printed, and python_profiler is set to None so it is not used again;
  • the point is to capture performance information early in training and then turn profiling off so it does not slow down the rest of the run.

  22. End of a training epoch:

A for loop can have an else clause: the else body runs after the loop finishes its iterations, but only if the loop was not interrupted by break (see the small demo after this list).

  • If the inner loop for batch in self.train_loader runs to completion, one training epoch has finished.
  • A log line notes that the epoch is complete.
  • The current epoch counter is advanced and the per-epoch example counter is reset.
  • If the new epoch is still below the maximum number of epochs, the dataset is reshuffled for the next epoch.

  23. Save the final checkpoint:

  • If checkpoint saving is still enabled (save_checkpoints) and an unsharded save interval is configured (self.cfg.save_interval_unsharded), save a final unsharded model checkpoint.
  • Otherwise, if checkpoints are being kept at all (self.cfg.save_num_checkpoints_to_keep != 0), save a final sharded checkpoint.
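
A small demo of the for / else behaviour the loop relies on: the else branch runs only when the loop finishes without break, which is how fit() tells "epoch completed" apart from "stopped early via stop_at".

    for step in range(3):
        if step == 99:                 # never triggers
            break
    else:
        print("epoch complete")        # printed

    for step in range(3):
        if step == 1:
            break
    else:
        print("epoch complete")        # NOT printed: the loop was broken out of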

15. Startup and shutdown

    def close(self, exit_code: int = 0) -> None:
        if self.indices_file is not None:
            self.indices_file.flush()
            self.indices_file.close()
        if wandb.run is not None:
            wandb.finish(exit_code=exit_code, quiet=True)

    def __enter__(self) -> Trainer:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        del exc_val, exc_tb
        self.close(0 if exc_type is None else 1)

  • self.indices_file.flush(): flush the file object's buffer so all pending data is written to disk.

  • self.indices_file.close(): close the file.

  • wandb.finish ends the W&B run; exit_code is the exit code, and quiet=True suppresses output on shutdown.

  • The class implements the context-manager protocol via __enter__ and __exit__. Context managers are used to manage acquiring and releasing resources, guaranteeing that resources are set up before a code block runs and cleaned up afterwards. In short, __enter__ lets the class be used with a with statement (here it simply returns the trainer itself), and when execution leaves the with block, __exit__ runs and performs the cleanup (see the minimal sketch below).

    • __enter__(self) -> Trainer: called when entering the with block; it returns the context-manager object, usually self. In this implementation it returns self, so the trainer itself acts as the context manager.

    • __exit__(self, exc_type, exc_val, exc_tb) -> None: called when leaving the with block, responsible for releasing resources. exc_type, exc_val, and exc_tb are the exception's type, value, and traceback. Here it discards exc_val and exc_tb and calls self.close(0 if exc_type is None else 1): exit code 0 when no exception occurred (exc_type is None), 1 otherwise.

    In short, the class is designed for use with a with statement: resources are set up when the block is entered and released when it is left, because __exit__ calls self.close().
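
A minimal stand-in that mirrors this __enter__/__exit__ pattern (it is not the real Trainer class, just the same shape):

    class MiniTrainer:
        """Minimal context-manager sketch mirroring Trainer's __enter__/__exit__."""

        def close(self, exit_code: int = 0) -> None:
            print(f"closing with exit_code={exit_code}")

        def __enter__(self) -> "MiniTrainer":
            return self

        def __exit__(self, exc_type, exc_val, exc_tb) -> None:
            del exc_val, exc_tb
            self.close(0 if exc_type is None else 1)

    with MiniTrainer() as trainer:
        pass                            # normal exit -> "closing with exit_code=0"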

ending:

There is a lot to take away from this file: it essentially defines every step a training pipeline needs, along with the supporting trainer components, and how those components are implemented is the next thing worth digging into.
