Reference: https://www.tensorflow.org/tutorials/distribute/custom_training
- Distribution lives in the `tf.distribute` module, which provides the different strategies below. Cross-GPU communication is delegated to a collective-communication library; the default is NVIDIA's NCCL (NVIDIA Collective Communications Library).
- `tf.distribute.MirroredStrategy`: every GPU holds a full replica of the graph and variables and keeps them in sync; each batch waits for all replicas to finish before the gradients are reduced.
- `CentralStorageStrategy`: variables are kept on the CPU, and the GPUs work on copies pulled from the CPU.
- `MultiWorkerMirroredStrategy`: much like `MirroredStrategy`, but scaled out further across machines; each worker can itself have multiple GPUs.
- `ParameterServerStrategy`: multi-machine; some machines act as parameter servers (ps) while the rest do the computation (workers). Roles are assigned through the `TF_CONFIG` environment variable (see the sketch after this list).
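As an illustration of how `TF_CONFIG` assigns those roles, here is a sketch with made-up host names and ports; every process in the cluster receives the same `cluster` dict, and only the `task` entry differs per process:

```python
import json
import os

# Hypothetical cluster: two compute workers and one parameter server.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    # This particular process runs as the first worker.
    "task": {"type": "worker", "index": 0},
})
```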
The model and optimizer must be created inside `strategy.scope()`:
```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()  # any Keras model builder, defined earlier in the tutorial
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
```
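The training loop below also relies on a loss function, metrics, and distributed datasets that the snippet leaves out. A minimal sketch of those definitions in the spirit of the linked tutorial; the batch size, `EPOCHS`, `checkpoint_prefix`, and the `train_dataset`/`test_dataset` pipelines are placeholder assumptions:

```python
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
checkpoint_prefix = "./training_checkpoints/ckpt"

# Reduction.NONE keeps per-example losses so we can average over the
# *global* batch size ourselves (see the note after the training loop).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

test_loss = tf.keras.metrics.Mean(name="test_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

# train_dataset / test_dataset are ordinary tf.data pipelines batched with
# GLOBAL_BATCH_SIZE; the strategy splits each batch across the replicas.
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
```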
```python
with strategy.scope():
    def train_step(inputs):
        images, labels = inputs

        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_accuracy.update_state(labels, predictions)
        return loss

    def test_step(inputs):
        images, labels = inputs

        predictions = model(images, training=False)
        t_loss = loss_object(labels, predictions)

        test_loss.update_state(t_loss)
        test_accuracy.update_state(labels, predictions)
```
```python
with strategy.scope():
    # `strategy.run` (named `experimental_run_v2` in older TF 2.x releases)
    # replicates the given computation and runs it with distributed input.
    @tf.function
    def distributed_train_step(dataset_inputs):
        per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                               axis=None)

    @tf.function
    def distributed_test_step(dataset_inputs):
        return strategy.run(test_step, args=(dataset_inputs,))

    for epoch in range(EPOCHS):
        # Training loop
        total_loss = 0.0
        num_batches = 0
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # Test loop
        for x in test_dist_dataset:
            distributed_test_step(x)

        if epoch % 2 == 0:
            checkpoint.save(checkpoint_prefix)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print(template.format(epoch + 1, train_loss,
                              train_accuracy.result() * 100,
                              test_loss.result(),
                              test_accuracy.result() * 100))

        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()
```
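A note on why the train step reduces with `ReduceOp.SUM` rather than a mean: `compute_loss` already divides each replica's summed loss by the global batch size, so adding the per-replica results reproduces the plain single-device batch mean. A toy check of that arithmetic (plain NumPy, no strategy involved; the numbers are made up):

```python
import numpy as np

per_example = np.arange(8, dtype=np.float32)  # pretend losses for a global batch of 8
global_mean = per_example.mean()              # what single-device training would report

# Two replicas each see half the batch but divide by the *global* batch size.
replica0 = per_example[:4].sum() / 8
replica1 = per_example[4:].sum() / 8
assert np.isclose(replica0 + replica1, global_mean)
```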