1.Keras 的分布式训练
概述
tf. distribute. Strategy API 提供了一个抽象的 API ,用于跨多个处理单元(processing units)分布式训练。
它的目的是允许用户使用现有模型和训练代码,只需要很少的修改,就可以启用分布式训练。
strategy = tf. distribute. MirroredStrategy( )
INFO: tensorflow: Using MirroredStrategy with devices ( '/job:localhost/replica:0/task:0/device:GPU:0' , )
INFO: tensorflow: Using MirroredStrategy with devices ( '/job:localhost/replica:0/task:0/device:GPU:0' , )
print ( 'Number of devices: {}' . format ( strategy. num_replicas_in_sync) )
Number of devices: 1
在训练具有多个 GPU 的模型时,您可以通过增加批量大小(batch size)来有效地使用额外的计算能力。
通常来说,使用适合 GPU 内存的最大批量大小(batch size),并相应地调整学习速率。
2.多工作器(worker)配置
多工作器(worker)配置
现在让我们进入多工作器(worker) 训练的世界。在 TensorFlow 中,需要 TF_CONFIG 环境变量来训练多台机器,每台机器可能具有不同的角色。 TF_CONFIG用于指定作为集群一部分的每个 worker 的集群配置。
TF_CONFIG 有两个组件:cluster 和 task 。 cluster 提供有关训练集群的信息,这是一个由不同类型的工作组成的字典,例如 worker 。在多工作器(worker)培训中,除了常规的“工作器”之外,通常还有一个“工人”承担更多责任,比如保存检查点和为 TensorBoard 编写摘要文件。这样的工作器(worker)被称为“主要”工作者,习惯上worker 中 index 0 被指定为主要的 worker(事实上这就是tf. distribute. Strategy的实现方式)。 另一方面,task 提供当前任务的信息。
os. environ[ 'TF_CONFIG' ] = json. dumps( {
'cluster' : {
'worker' : [ "localhost:12345" , "localhost:23456" ]
} ,
'task' : { 'type' : 'worker' , 'index' : 0 }
} )
选择正确的策略
在 TensorFlow 中,分布式训练包括同步训练(其中训练步骤跨工作器和副本同步)、异步训练(训练步骤未严格同步)。
MultiWorkerMirroredStrategy 是同步多工作器训练的推荐策略,将在本指南中进行演示。
要训练模型,请使用 tf. distribute. experimental. MultiWorkerMirroredStrategy 的实例。
strategy = tf. distribute. experimental. MultiWorkerMirroredStrategy( )
NUM_WORKERS = 2
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
train_datasets = make_datasets_unbatched( ) . batch( GLOBAL_BATCH_SIZE)
with strategy. scope( ) :
multi_worker_model = build_and_compile_cnn_model( )
multi_worker_model. fit( x= train_datasets, epochs= 3 , steps_per_epoch= 5 )
数据集分片和批(batch)大小
options = tf. data. Options( )
options. experimental_distribute. auto_shard_policy = tf. data. experimental. AutoShardPolicy. OFF
train_datasets_no_auto_shard = train_datasets. with_options( options)
callbacks = [ tf. keras. callbacks. ModelCheckpoint( filepath= '/tmp/keras-ckpt' ) ]
with strategy. scope( ) :
multi_worker_model = build_and_compile_cnn_model( )
multi_worker_model. fit( x= train_datasets,
epochs= 3 ,
steps_per_epoch= 5 ,
callbacks= callbacks)