Keras深度学习框架第十六讲:使用Keras 3进行分布式训练

70 篇文章 0 订阅
48 篇文章 0 订阅

1、绪论

Keras分布式API是一个新接口,旨在促进在各种后端(如JAX、TensorFlow和PyTorch)上进行分布式深度学习。这个强大的API引入了一系列工具,支持数据和模型并行性,允许在多个加速器和主机上高效扩展深度学习模型。无论是利用GPU还是TPU的强大功能,该API都提供了一种简化的方法来初始化分布式环境、定义设备网格,并在计算资源之间组织张量的布局。通过像DataParallel和ModelParallel这样的类,它抽象了并行计算中的复杂性,使开发人员更容易加速他们的机器学习工作流程。

Keras分布式API具备广泛的用途:

  • 多后端支持:Keras原本就支持多种深度学习后端,如TensorFlow、Theano和CNTK。通过分布式API,用户可以根据自己的需求选择合适的后端,并享受到统一的接口和便捷性。
  • 动态后端选择:Keras 3.0能够动态为模型提供最佳性能的后端,而无需更改代码,保证以最高效率运行。这意味着可以将Keras模型与PyTorch生态的包、TensorFlow中的部署工具或生产工具,以及JAX大规模TPU训练基础设施一起使用,获得机器学习世界所提供的一切。
  • 数据并行性:通过数据并行性,可以在多个设备上同时处理数据批次,从而加速模型的训练过程。这对于大规模数据集尤其有用。
  • 模型并行性:模型并行性允许将模型的各个部分分布到不同的设备上,从而加速模型的训练和推理。这对于复杂的深度学习模型特别有用。
  • 预训练模型和模型迁移:Keras提供了许多预训练的深度学习模型,如VGG、ResNet、Inception等。这些模型可以直接在用户的任务上进行微调或迁移学习,从而加快模型的训练效率。通过分布式API,这些预训练模型可以更容易地在多个设备上并行训练。
  • 简化的开发流程:Keras分布式API通过抽象并行计算的复杂性,使得开发人员更容易理解和实现分布式深度学习。它提供了直观的接口和工具,使得构建、调试和部署深度学习模型变得更加容易。

总的来说,Keras分布式API为深度学习开发者提供了一个强大的工具集,使得他们能够更轻松地利用多设备、多后端的优势来加速深度学习模型的训练和推理。

2、Keras 3分布式训练API的工作机制

2.1 工作原理

Keras分布式API提供了一个全局编程模型,允许开发者编写在全局上下文(就像在处理单个设备一样)中操作张量的应用程序,同时自动管理多个设备之间的分布。该API利用底层框架(如JAX)通过称为单程序多数据(SPMD)扩展的过程,根据分片指令来分发程序和张量。

通过将应用程序与分片指令解耦,该API使得相同的应用程序可以在单个设备、多个设备甚至多个客户端上运行,同时保持其全局语义。

2.2 系统设置

import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

2.3DeviceMeshTensorLayout

在Keras分布式API中,keras.distribution.DeviceMesh 类代表了一个为分布式计算配置的计算设备集群。它与JAX的 jax.sharding.Mesh 和TensorFlow的 tf.dtensor.Mesh 中的类似概念相一致,用于将物理设备映射到一个逻辑网格结构。

然后,TensorLayout 类指定了张量如何在 DeviceMesh 上进行分布,详细说明了张量在对应于 DeviceMesh 中轴名称的指定轴上的分片。

# Retrieve the local available gpu devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)

2.4 分发

Keras中的Distribution类是一个基础抽象类,旨在开发自定义分发策略。它封装了在设备网格上分发模型的变量、输入数据和中间计算所需的核心逻辑。作为最终用户,程序员不需要直接与这个类交互,而是与其子类(如DataParallelModelParallel)进行交互。

2.5 并行数据DataParallel

Keras分布式API中的DataParallel类是为分布式训练中的数据并行策略设计的,其中模型权重在DeviceMesh中的所有设备上被复制,每个设备处理输入数据的一部分。

以下是这个类的一个使用示例。

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)

2.6ModelParallel和LayoutMap

当模型权重过大而无法适应单个加速器时,ModelParallel将非常有用。这种设置允许程序员将模型权重或激活张量分散到DeviceMesh上的所有设备上,并为大型模型启用水平扩展。

与DataParallel模型不同,后者会完全复制所有权重,ModelParallel下的权重布局通常需要一些定制以获得最佳性能。引入了LayoutMap的目的,是让程序员可以从全局角度为任何权重和中间张量指定TensorLayout。

LayoutMap是一个类似于字典的对象,它将字符串映射到TensorLayout实例。与普通的Python字典不同,它在检索值时将字符串键视为正则表达式。这个类允许程序员定义TensorLayout的命名模式,然后检索相应的TensorLayout实例。通常,用于查询的键是variable.path属性,它是变量的标识符。作为快捷方式,在插入值时,也允许使用轴名的元组或列表,并会将其转换为TensorLayout。

LayoutMap还可以选择性地包含一个DeviceMesh,以便在TensorLayout.device_mesh未设置时填充它。当使用键检索布局时,如果没有完全匹配的项,LayoutMap中的所有现有键都将被视为正则表达式,并再次与输入键进行匹配。如果有多个匹配项,将引发ValueError。如果没有找到匹配项,则返回None。

mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that for any weights that match with d1/kernel, it
# will be sharded with model dimensions (4 devices), same for the d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(
    mesh_2d, layout_map, batch_dim_name="data"
)

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the method, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)

同样,通过调整网格的形状,可以轻松地更改网格结构以调整计算以更多地偏向数据并行或模型并行。这样做不需要对其他代码进行任何更改。

full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)

3、总结

Keras分布式API中的ModelParallel类和LayoutMap提供了用于在多个设备上分发模型权重和激活张量的机制,支持大型模型的水平扩展。ModelParallel允许在DeviceMesh上的所有设备上分散模型权重,而LayoutMap则允许用户从全局角度为任何权重和中间张量指定TensorLayout。通过调整网格的形状,可以轻松地在数据并行和模型并行之间调整计算,而无需修改其他代码。这种灵活性使得在分布式环境中训练大型模型变得更加高效和灵活。

ModelParallel

在分布式深度学习训练中,当模型变得非常大,以至于单个计算设备(如GPU或TPU)无法容纳整个模型时,就需要使用模型并行(ModelParallel)策略。ModelParallel允许将模型的各个部分(通常是层或子图)分布在不同的计算设备上,每个设备负责计算模型的一个子集。

具体来说,在ModelParallel设置中,模型权重被分割并分配到DeviceMesh上的不同设备上。每个设备处理模型的不同部分,并与其他设备进行通信以完成前向和后向传播。通过这种方式,大型模型可以在多个设备上并行运行,从而加快训练速度。

ModelParallel的优点是能够处理超大型模型,并且可以在具有不同计算能力的设备上进行扩展。然而,它也需要额外的通信开销来同步不同设备上的模型权重和梯度。

** LayoutMap**

LayoutMap是Keras分布式API中的一个重要组件,它允许用户为模型的权重和中间张量指定TensorLayout。TensorLayout描述了张量在DeviceMesh上的分布方式,包括张量的分片和放置策略。

LayoutMap是一个类似于字典的对象,它将字符串键映射到TensorLayout实例。这些字符串键通常表示模型变量的名称或路径。用户可以通过LayoutMap为不同的变量指定不同的TensorLayout,以优化计算和内存使用。

LayoutMap的一个重要特性是,它的键可以被视为正则表达式。当使用某个键查询TensorLayout时,LayoutMap将尝试使用所有现有的键作为正则表达式与查询键进行匹配。如果找到多个匹配项,将引发错误;如果没有找到匹配项,则返回None。这种灵活性使得用户可以更容易地管理和应用TensorLayout。

ModelParallel和LayoutMap是Keras分布式API中用于支持大型模型分布式训练的重要组件。ModelParallel允许将模型的各个部分分布在多个设备上,而LayoutMap则允许用户为模型的权重和中间张量指定TensorLayout。这些功能使得在分布式环境中训练大型模型变得更加高效和灵活。

  • 9
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值