accelerate一些类和函数说明一

ProjectConfiguration类

from dataclasses import dataclass, field

@dataclass
class ProjectConfiguration:
    """
    Configuration for the Accelerator object based on inner-project needs.
    """

    project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."})
    logging_dir: str = field(
        default=None,
        metadata={
            "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`."
        },
    )
    automatic_checkpoint_naming: bool = field(
        default=False,
        metadata={"help": "Whether saved states should be automatically iteratively named."},
    )

    total_limit: int = field(
        default=None,
        metadata={"help": "The maximum number of total saved states to keep."},
    )

    iteration: int = field(
        default=0,
        metadata={"help": "The current save iteration."},
    )

    save_on_each_node: bool = field(
        default=False,
        metadata={
            "help": (
                "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
                " only on the main one"
            )
        },
    )

    def set_directories(self, project_dir: str = None):
        "Sets `self.project_dir` and `self.logging_dir` to the appropriate values."
        self.project_dir = project_dir
        if self.logging_dir is None:
            self.logging_dir = project_dir

    def __post_init__(self):
        self.set_directories(self.project_dir)

ProjectConfiguration 类是一个使用 Python 的 dataclass 模块定义的数据类,用于配置与项目相关的各种参数,特别是针对使用加速器进行分布式训练的场景。下面是对这个类的详细解释:

类的属性

  1. project_dir (str): 用于存储数据的目录路径。如果没有指定,默认为 None

    • 用于设定项目的根目录,可以在该目录下存储数据、模型或其他项目相关的文件。
  2. logging_dir (str): 用于存储日志的目录路径。如果没有指定,默认使用 project_dir

    • 用于存储本地日志,帮助跟踪训练过程中的信息。如果不指定,会使用 project_dir 的路径。
  3. automatic_checkpoint_naming (bool): 是否自动为保存的状态命名。默认为 False

    • 如果设为 True,保存的模型和检查点会自动使用递增的命名方式,这样可以避免覆盖之前的检查点。
  4. total_limit (int): 要保留的最大检查点数。如果没有指定,默认为 None

    • 用于限制保存的检查点数量,以防止磁盘空间被过多的检查点占用。
  5. iteration (int): 当前的保存迭代次数。默认为 0

    • 这个属性可以用来跟踪训练过程中的保存次数,例如每次保存后这个计数器都会增加。
  6. save_on_each_node (bool): 在多节点分布式训练中,是否在每个节点上保存模型和检查点。默认为 False

    • 如果设为 True,每个节点都会保存模型和检查点,否则只在主节点上保存。这对于一些分布式训练场景可能非常重要。

方法

  1. set_directories(self, project_dir: str = None):

    • 该方法用于设置 project_dirlogging_dir 的值。如果 logging_dir 未设置,则默认使用 project_dir 的值。
  2. __post_init__(self):

    • dataclass 提供的特殊方法,在对象初始化后自动调用。此方法确保在初始化时调用 set_directories() 方法,以便在创建对象时根据提供的 project_dir 设置路径。

总结

这个类的主要作用是提供一个结构化的方式来定义和管理项目的配置参数。通过使用 dataclass,它使得代码更加简洁,同时使用了默认值和元数据来提供对每个参数的详细帮助信息。该类适用于管理训练过程中的目录设置、日志记录、检查点保存策略等配置项,特别是对于多节点分布式训练的场景。

PrecisionType枚举

class PrecisionType(BaseEnum):
    """Represents a type of precision used on floating point values

    Values:

        - **NO** -- using full precision (FP32)
        - **FP16** -- using half precision
        - **BF16** -- using brain floating point precision
    """

    NO = "no"
    FP8 = "fp8"
    FP16 = "fp16"
    BF16 = "bf16"

class BaseEnum(enum.Enum, metaclass=EnumWithContains):
    "An enum class that can get the value of an item with `str(Enum.key)`"

    def __str__(self):
        return self.value

    @classmethod
    def list(cls):
        "Method to list all the possible items in `cls`"
        return list(map(str, cls))

class EnumWithContains(enum.EnumMeta):
    "A metaclass that adds the ability to check if `self` contains an item with the `in` operator"

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True

EnumWithContains 通过实现 contains 方法,为枚举类提供了使用 in 运算符的能力,从而简化了检查某个值是否是枚举成员的操作。

kwargs_handlers参数

self.ddp_handler: 可能用于处理分布式数据并行(DDP, Distributed Data Parallel)相关的配置或操作。
self.scaler_handler: 可能用于处理混合精度训练中的缩放操作,管理 torch.cuda.amp.GradScaler 或类似的功能。
self.init_handler: 可能用于管理初始化操作或配置,确保模型或环境正确初始化。
self.fp8_recipe_handler: 可能用于处理 FP8(浮点8位)计算的配方或配置,这在一些特定的硬件上有助于提高计算效率。
self.autocast_handler: 可能用于管理自动混合精度(autocasting)的设置,以便在需要时自动切换计算精度。
self.profile_handler: 可能用于性能分析和优化,管理与模型或训练过程性能分析相关的配置或操作。
self.has_lomo_optimizer: 一个布尔值,可能用于指示当前是否在使用 LoMo 优化器。

AcceleratorState类

self.__dict__
self.state
在这里插入图片描述
在这里插入图片描述
关于优化器
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Deepspeed的ZeRO3调试

config yaml文件的配置
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
deepspeed_with_config_supoort.py
在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值