本篇博客主要介绍TF的分布式训练,重点从代码层面进行讲解。理论部分可以参考深度学习分布式训练实战(一)
TF的分布式实现方式
TF的分布式有两种实现方式,一种是图内分布式(In-graph replication);一种是图间分布式(Between-graph replication)。图内分布式中,计算图只有一个,需要一个中心节点分配计算任务并更新参数,由于中心节点的存在,中心节点容易成为瓶颈。图间分布式中,计算图有多个,但是不同计算图的相同变量通过tf.train.replica_device_setter函数放到同一个服务器上,这种情况下,各个计算图相互独立(参数只有一份,计算图有多个),并行度更高,适合异步更新,同步更新下相对麻烦,不过TF给了接口tf.train.SyncReplicasOptimizer函数来帮助实现参数的同步更新,所以图间分布式应用相对广泛一些。
关于数据并行,模型并行可以参考深度学习分布式训练实战(一)
大部分情况下,我们使用图间分布式,图内分布式一般只会在模型太大的情况下使用。对于图间分布式,其基于gRPC通信框架,模型参数只有一份,计算图有多份,一个master负责创建主session,多个worker执行计算图任务。模型训练过程中,每个计算图计算出各自梯度,然后对参数进行更新。更新方式有两种:同步更新,异步更新。
分布式TF中,TF需要建立一个集群,然后在集群中建立两个job,一个是ps job,负责参数初始化,参数更新,一个job下面可以有多个task(有多个task,说明有多台机器,或者GPU负责参数初始化,更新)。一个是woker job,负责计算图的运算,计算梯度,一个worker job下面也可以有很多个task(有多个task,说明有多台机器,或者GPU负责运行计算图)。
参数异步更新的分布式训练
参数同步更新基本上和这里写的差不多TensorFlow分布式部署
。只不过为了方便在本机上调试,所以改了一点点。(自己的笔记本没有GPU),介绍下面几个重点的语句:
tf.train.ClusterSpec()
:创建一个集群对象
tf.train.Server()
:在这个集群上面创建一个服务器,根据实际情况,可以是参数服务器,也可以是计算服务器
tf.train.Supervisor()
:创建一个监视器,就是用来监控训练过程的,个人感觉主要就是方便恢复模型训练,其logdir
参数为训练日志目录,如果里面有模型,则直接恢复训练。所以如果想重新训练,需要删除这个目录。
sv.managed_session()
:启动Session,相比于其他启动Session的方式,多了一些功能。可以参考TensorFlow 中三种启动图用法
具体代码如下:
# tensorflow distribute train by asynchronously update
import tensorflow as tf
import numpy as np
tf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "", "one of ps or worker")
tf.app.flags.DEFINE_integer("task_index", 0, "0, 1, 2...")
FLAGS = tf.app.flags.FLAGS
def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
# Create a cluster from the parameter server and worker server
cluster = tf.train.ClusterSpec({
"ps":ps_hosts, "worker":worker_hosts})
# Create and start a server for the local task
server = tf.train.Server(cluster, job_name = FLAGS.job_name, task_index=FLAGS.task_index)
# 如果是参数服务器,则直接阻塞,等待计算服务器下达参数初始化,参数更新命令就可以了。
# 不过“下达命令”这个是TF内部实现的,没有显式实现
if FLAGS.job_name == "ps":
server.join()