文章目录
ProjectConfiguration类
from dataclasses import dataclass, field
@dataclass
class ProjectConfiguration:
"""
Configuration for the Accelerator object based on inner-project needs.
"""
project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."})
logging_dir: str = field(
default=None,
metadata={
"help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`."
},
)
automatic_checkpoint_naming: bool = field(
default=False,
metadata={"help": "Whether saved states should be automatically iteratively named."},
)
total_limit: int = field(
default=None,
metadata={"help": "The maximum number of total saved states to keep."},
)
iteration: int = field(
default=0,
metadata={"help": "The current save iteration."},
)
save_on_each_node: bool = field(
default=False,
metadata={
"help": (
"When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
" only on the main one"
)
},
)
def set_directories(self, project_dir: str = None):
"Sets `self.project_dir` and `self.logging_dir` to the appropriate values."
self.project_dir = project_dir
if self.logging_dir is None:
self.logging_dir = project_dir
def __post_init__(self):
self.set_directories(self.project_dir)
ProjectConfiguration
类是一个使用 Python 的 dataclass
模块定义的数据类,用于配置与项目相关的各种参数,特别是针对使用加速器进行分布式训练的场景。下面是对这个类的详细解释:
类的属性
-
project_dir
(str
): 用于存储数据的目录路径。如果没有指定,默认为None
。- 用于设定项目的根目录,可以在该目录下存储数据、模型或其他项目相关的文件。
-
logging_dir
(str
): 用于存储日志的目录路径。如果没有指定,默认使用project_dir
。- 用于存储本地日志,帮助跟踪训练过程中的信息。如果不指定,会使用
project_dir
的路径。
- 用于存储本地日志,帮助跟踪训练过程中的信息。如果不指定,会使用
-
automatic_checkpoint_naming
(bool
): 是否自动为保存的状态命名。默认为False
。- 如果设为
True
,保存的模型和检查点会自动使用递增的命名方式,这样可以避免覆盖之前的检查点。
- 如果设为
-
total_limit
(int
): 要保留的最大检查点数。如果没有指定,默认为None
。- 用于限制保存的检查点数量,以防止磁盘空间被过多的检查点占用。
-
iteration
(int
): 当前的保存迭代次数。默认为0
。- 这个属性可以用来跟踪训练过程中的保存次数,例如每次保存后这个计数器都会增加。
-
save_on_each_node
(bool
): 在多节点分布式训练中,是否在每个节点上保存模型和检查点。默认为False
。- 如果设为
True
,每个节点都会保存模型和检查点,否则只在主节点上保存。这对于一些分布式训练场景可能非常重要。
- 如果设为
方法
-
set_directories(self, project_dir: str = None)
:- 该方法用于设置
project_dir
和logging_dir
的值。如果logging_dir
未设置,则默认使用project_dir
的值。
- 该方法用于设置
-
__post_init__(self)
:dataclass
提供的特殊方法,在对象初始化后自动调用。此方法确保在初始化时调用set_directories()
方法,以便在创建对象时根据提供的project_dir
设置路径。
总结
这个类的主要作用是提供一个结构化的方式来定义和管理项目的配置参数。通过使用 dataclass
,它使得代码更加简洁,同时使用了默认值和元数据来提供对每个参数的详细帮助信息。该类适用于管理训练过程中的目录设置、日志记录、检查点保存策略等配置项,特别是对于多节点分布式训练的场景。
PrecisionType枚举
class PrecisionType(BaseEnum):
"""Represents a type of precision used on floating point values
Values:
- **NO** -- using full precision (FP32)
- **FP16** -- using half precision
- **BF16** -- using brain floating point precision
"""
NO = "no"
FP8 = "fp8"
FP16 = "fp16"
BF16 = "bf16"
class BaseEnum(enum.Enum, metaclass=EnumWithContains):
"An enum class that can get the value of an item with `str(Enum.key)`"
def __str__(self):
return self.value
@classmethod
def list(cls):
"Method to list all the possible items in `cls`"
return list(map(str, cls))
class EnumWithContains(enum.EnumMeta):
"A metaclass that adds the ability to check if `self` contains an item with the `in` operator"
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True
EnumWithContains 通过实现 contains 方法,为枚举类提供了使用 in 运算符的能力,从而简化了检查某个值是否是枚举成员的操作。
kwargs_handlers参数
self.ddp_handler: 可能用于处理分布式数据并行(DDP, Distributed Data Parallel)相关的配置或操作。
self.scaler_handler: 可能用于处理混合精度训练中的缩放操作,管理 torch.cuda.amp.GradScaler 或类似的功能。
self.init_handler: 可能用于管理初始化操作或配置,确保模型或环境正确初始化。
self.fp8_recipe_handler: 可能用于处理 FP8(浮点8位)计算的配方或配置,这在一些特定的硬件上有助于提高计算效率。
self.autocast_handler: 可能用于管理自动混合精度(autocasting)的设置,以便在需要时自动切换计算精度。
self.profile_handler: 可能用于性能分析和优化,管理与模型或训练过程性能分析相关的配置或操作。
self.has_lomo_optimizer: 一个布尔值,可能用于指示当前是否在使用 LoMo 优化器。
AcceleratorState类
关于优化器
Deepspeed的ZeRO3调试
config yaml文件的配置
deepspeed_with_config_supoort.py