1、绪论
使用JAX进行多GPU分布式训练是一种高效的策略,特别适用于大规模深度学习模型的训练。JAX作为一个包含可组合函数变换的数值计算库,为深度学习研究提供了强大的工具和灵活性。
1.1 使用JAX进行分布式训练的优势
以下是关于使用JAX进行多GPU分布式训练的一些关键点和优势:
-
可组合函数变换:JAX的设计允许用户构建可组合的数值计算函数,这些函数可以轻松地在不同的硬件上执行,包括多个GPU。这种灵活性使得分布式训练变得更为简单和高效。
-
自动微分:JAX支持自动微分,可以自动计算函数关于其输入的导数。这对于深度学习训练至关重要,因为梯度下降等优化算法需要用到这些导数来更新模型的参数。JAX的自动微分功能可以无缝地与多GPU分布式训练结合使用。
-
XLA支持:JAX依赖于XLA(Accelerated Linear Algebra),这是一个由Google开发的用于优化机器学习计算的线性代数库。XLA支持JIT(Just-In-Time)编译和跨设备(CPU/GPU/TPU)执行,使得JAX代码可以在多个GPU上高效运行。通过XLA,JAX可以将Python和JAX代码编译成优化的内核,这些内核可以在GPU上执行,从而实现高效的分布式训练。
-
分布式策略:虽然JAX本身并不是一个专门的深度学习框架,但它可以与其他框架(如TensorFlow)的分布式训练策略结合使用。例如,可以使用TensorFlow的分布式策略(如MirroredStrategy)来管理多个GPU之间的数据并行和模型并行。这些策略可以确保数据在多个GPU之间均匀分布,并在每个GPU上执行相同的计算任务,从而实现高效的分布式训练。
-
通信机制:在分布式训练中,节点之间的通信是一个关键问题。JAX可以利用TensorFlow等框架提供的通信机制(如gRPC)来实现节点之间的数据交换和同步。这些机制可以确保数据在多个GPU之间高效地传输,从而加速训练过程。
-
扩展性:使用JAX进行多GPU分布式训练的一个优点是它的扩展性。由于JAX是基于Python的库,因此可以轻松地与其他Python库和工具集成使用。这使得研究人员可以灵活地构建和扩展他们的深度学习模型,以适应不同的硬件和分布式训练需求。
-
性能优化:JAX还提供了一些性能优化工具和技术,如向量化、并行化和内存优化等。这些技术可以帮助研究人员进一步提高他们的深度学习模型在多GPU分布式训练中的性能。
使用JAX进行多GPU分布式训练是一种高效且灵活的策略,可以加速大规模深度学习模型的训练过程并提高性能。通过结合JAX的自动微分、XLA支持和分布式策略等功能,研究人员可以轻松地构建和扩展他们的深度学习模型,以适应不同的硬件和分布式训练需求。
1.2 使用JAX进行分布式训练的流程
使用JAX进行多GPU分布式训练的流程通常涉及以下几个关键步骤:
-
环境准备:
- 安装JAX库和相关的依赖项,如NumPy等。
- 配置多GPU环境,确保所有的GPU都可以被正确地识别和使用。
-
数据准备:
- 准备训练数据集,并进行适当的预处理。
- 在多个GPU之间分割数据,实现数据并行化。这可以通过使用JAX的API或者与JAX兼容的深度学习框架(如TensorFlow)的分布式策略来实现。
-
模型定义:
- 使用JAX或兼容的深度学习框架定义深度学习模型。
- 确保模型可以支持多GPU训练,例如通过定义模型参数为可分布在多个GPU上的数据结构。
-
分布式策略配置:
- 选择合适的分布式策略,如数据并行化或模型并行化。
- 配置分布式策略的参数,如每个GPU处理的数据量、模型切片的分配等。
-
训练循环:
- 编写训练循环代码,使用JAX的autograd模块来计算梯度,并使用优化器更新模型参数。
- 在训练循环中,确保数据被正确地分发到多个GPU上,并在每个GPU上执行前向传播、反向传播和参数更新等操作。
- 在GPU之间同步梯度更新,以确保所有GPU上的模型参数保持一致。这可以通过使用JAX的通信机制(如All-Reduce操作)来实现。
-
性能优化:
- 使用JAX提供的性能优化工具和技术来加速训练过程。例如,可以使用JAX的JIT编译功能来加速计算,或者使用向量化操作来减少内存访问次数。
- 监控训练过程中的性能指标,如训练速度、GPU利用率等,并根据需要进行调整和优化。
-
模型评估与保存:
- 在训练过程中定期评估模型的性能,例如使用验证集或测试集来计算准确率和损失等指标。
- 在模型达到预设的性能要求时,将其保存为文件以便后续使用。
-
错误处理和日志记录:
- 在训练过程中添加适当的错误处理和日志记录机制,以便在出现问题时能够快速定位和解决。
需要注意的是,具体的实现细节可能会因所使用的深度学习框架、硬件环境以及模型复杂度等因素而有所不同。因此,在实际应用中,可能需要根据具体情况进行调整和优化。此外,JAX作为一个较为底层的库,可能需要与其他库和工具(如TensorFlow、NumPy等)结合使用以实现完整的多GPU分布式训练流程。
2、基于JAX的多GPU分布式训练
2.1 本文讨论的内容范围
在多设备之间分配计算通常有两种方式:
数据并行性,这种方式是在多个设备或机器上复制单个模型。每个设备或机器处理不同的数据批次,然后它们合并各自的结果。这种设置有许多变体,这些变体在模型的不同副本如何合并结果、是否在每个批次之后保持同步或是否更松散地耦合等方面有所不同。
模型并行性,这种方式是在不同的设备上运行单个模型的不同部分,共同处理单个数据批次。这种方法最适合具有自然并行架构的模型,例如具有多个分支的模型。
本问主要讨论数据并行性,特别是同步数据并行性,其中模型的不同副本在它们处理每个批次后保持同步。同步性使模型的收敛行为与单设备训练时看到的相同。
具体来说,本文将讨论如何使用jax.sharding API在多个GPU或TPU(通常在单个机器上安装2到16个)上训练Keras模型,同时尽可能少地修改代码。这是研究人员和小规模行业工作流程中最常见的设置。
2.2 系统设置
import os
os.environ["KERAS_BACKEND"] = "jax"
import jax
import numpy as np
import tensorflow as tf
import keras
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
def get_model():
# Make a simple convnet with batch normalization and dropout.
inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
x
)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(
filters=24,
kernel_size=6,
use_bias=False,
strides=2,
)(x)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(
filters=32,
kernel_size=6,
padding="same",
strides=2,
name="large_k",
)(x)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(256, activation="relu")(x)
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
return model
def get_datasets():
# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Scale images to the [0, 1] range
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
# Create TF Datasets
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
return train_data, eval_data
2.3 单主机多设备同步训练
在这种设置中,你有一台机器,上面配备了多个GPU或TPU(通常是2到16个)。每个设备将运行你的模型的一个副本(称为一个副本)。为了简化说明,以下我们将假设我们处理的是8个GPU。
2.3.1 工作原理
在训练的每一步中:
- 当前的数据批次(称为全局批次)被分割成8个不同的子批次(称为本地批次)。例如,如果全局批次有512个样本,那么每个本地批次将包含64个样本。
- 这8个副本各自独立地处理一个本地批次:它们运行前向传播,然后进行反向传播,输出模型在本地批次上的损失相对于权重的梯度。
- 来自本地梯度的权重更新在8个副本之间高效合并。因为这是在每一步结束时完成的,所以副本之间总是保持同步。
在实践中,模型副本的权重同步更新是在每个单独的权重变量级别上处理的。这是通过使用配置为复制变量的jax.sharding.NamedSharding
来完成的。
2.3.2 使用方法
要使用Keras模型进行单主机、多设备同步训练,程序员将使用jax.sharding功能。以下是其工作原理:
- 首先使用
mesh_utils.create_device_mesh
创建一个设备网格。 - 使用
jax.sharding.Mesh
、jax.sharding.NamedSharding
和jax.sharding.PartitionSpec
来定义如何对JAX数组进行分区。 -
- 我们通过使用不包含轴的规范来指定我们想要在所有设备上复制模型和优化器变量。- 我们通过使用在批次维度上进行分割的规范来指定我们想要将数据分片到设备上。
- 使用
jax.device_put
在开始时一次性地将模型和优化器变量复制到各个设备上。 - 在训练循环中,对于我们处理的每个批次,我们在调用训练步骤之前使用
jax.device_put
将批次分割到各个设备上。
以下是流程,其中每一步都被拆分成自己的实用函数:
# Config
num_epochs = 2
batch_size = 64
train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)
model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)
# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
y_pred, updated_non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss_value = loss(y, y_pred)
return loss_value, updated_non_trainable_variables
# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)
# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
trainable_variables, non_trainable_variables, optimizer_variables = train_state
(loss_value, non_trainable_variables), grads = compute_gradients(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
return loss_value, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices):
# All variables will be replicated on all devices
var_mesh = Mesh(devices, axis_names=("_"))
# In NamedSharding, axes not mentioned are replicated (all axes here)
var_replication = NamedSharding(var_mesh, P())
# Apply the distribution settings to the model variables
trainable_variables = jax.device_put(model.trainable_variables, var_replication)
non_trainable_variables = jax.device_put(
model.non_trainable_variables, var_replication
)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)
# Combine all state in a tuple
return (trainable_variables, non_trainable_variables, optimizer_variables)
num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))
# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh
data_sharding = NamedSharding(
data_mesh,
P(
"batch",
),
) # naming axes of the sharded partition
# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))
train_state = get_replicated_train_state(devices)
# Custom training loop
for epoch in range(num_epochs):
data_iter = iter(train_data)
for data in data_iter:
x, y = data
sharded_x = jax.device_put(x.numpy(), data_sharding)
loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
print("Epoch", epoch, "loss:", loss_value)
# Post-processing model state update to write them back into the model
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
variable.assign(value)
3、总结
jax.sharding
是 JAX (Just-in-time Accelerated XLA) 框架中的一个功能,用于支持在多个设备(如GPU或TPU)上分布式地处理数据和模型参数。在深度学习和大规模计算中,当你有多个计算设备可用时,sharding
(分片)是一个将数据或参数划分成多个片段,并将每个片段分配给一个或多个设备的技术。
jax.sharding
允许用户指定如何将数据或参数分片到这些设备上,并且为这种分布式处理提供了一些原语。这对于大规模训练特别有用,因为它允许你有效地利用多个设备上的内存和计算能力。
具体来说,jax.sharding
提供了以下功能:
-
定义分片策略:你可以使用
jax.sharding.PartitionSpec
来定义如何对数组进行分片。例如,你可以指定一个数组应该沿着其某个维度进行分片。 -
创建分片数组:使用
jax.sharding
的功能,你可以创建分片数组,这些数组在物理上分布在多个设备上,但在逻辑上被视为一个整体。 -
合并和分割:
jax.sharding
还提供了用于合并和分割分片数组的功能,以便在多个设备之间移动数据。 -
自动微分:尽管分片发生在物理层面上,但 JAX 的自动微分功能仍然可以无缝地工作,就像你正在处理一个普通的未分片数组一样。
-
模型复制:在同步训练中,你可能希望将模型的参数复制到所有设备上。
jax.sharding
提供了工具来确保这一点,并确保在更新时所有副本都保持同步。 -
与Keras和其他框架的集成:虽然 JAX 本身是一个独立的库,但它与 TensorFlow 的生态系统(包括 Keras)有很好的集成。因此,你可以使用
jax.sharding
来加速使用 Keras 构建的模型的训练。
总的来说,jax.sharding
提供了一种在多个设备上有效地分布式处理数据和模型参数的方法,这对于加速深度学习训练特别有用。
在单主机、多设备(如GPU或TPU)的同步训练设置中,可以使用JAX的sharding功能来在多个设备上并行训练Keras模型。该过程主要包括以下步骤:
-
创建设备网格:使用
mesh_utils.create_device_mesh
来创建一个设备网格,代表你的机器上的所有可用设备。 -
定义分区策略:使用
jax.sharding.Mesh
、jax.sharding.NamedSharding
和jax.sharding.PartitionSpec
来定义如何对JAX数组进行分区。这包括指定模型和优化器变量在所有设备上的复制策略(无轴规范),以及数据在设备上的分片策略(沿批次维度分割)。 -
初始化模型和优化器:使用
jax.device_put
将模型和优化器变量复制到所有设备上。这通常在训练开始前进行一次。 -
数据批次分片:在训练循环中,对于每个数据批次,使用
jax.device_put
将其分割到各个设备上,以便每个设备可以独立处理一部分数据。 -
同步训练:在每个设备上,模型副本都会独立处理其本地数据批次,执行前向传播和反向传播,并计算损失对权重的梯度。这些梯度会被高效地合并以同步更新所有模型副本的权重,从而确保它们保持一致。
-
循环迭代:训练循环将不断迭代数据,每个批次都会重复上述分片、前向传播、反向传播和权重更新的过程,直到达到预定的训练轮次或满足其他停止条件。
通过使用JAX的sharding功能,可以很容易地在多个设备上并行训练Keras模型,而无需对原始代码进行大量修改。这种设置对于研究人员和小规模工业应用来说非常常见,因为它可以有效地利用多设备计算能力来加速训练过程。