多机如何分布式运行tensorflow模型?
(原文发表在我的博客,欢迎访问
0x00.前言
对于比较复杂的模型,在本机或者单服务器上跑起来需要很长时间。在很多科研单位或公司,可能没有插满gpu的服务器,这时候怎么办呢,有没有可能多台服务器一起跑一个模型呢?
这里就要用到分布式的tensorflow(distributed tensorflow)。
下面介绍在集群上部署tensorflow的方法。
0x01.基本概念
在分布式tensorflow中,服务器被分为两类,一类叫做参数服务器(parameter server,简称ps),另一类叫做计算服务器(worker)。顾名思义,ps会存储参数,分发参数;而worker运行模型,与ps就参数进行交互。
1.训练方式
tensorflow中常用的并行化训练方式有同步模式和异步模式两种方式。
在同步模式中,worker同时读取参数,但是训练完成后不会单独对参数进行更新,而是等待所有worker运行完,统一更新参数。
而在异步训练中,不同worker会对参数独立的更新。
0x02.tensorflow官方示例
tensorflow的官方代码在https://github.com/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py,下面我给示例代码打了一些注释,有条件的朋友可以尝试跑一下
1.变量设置
首先设置tf.app.flags定义标记,在命令行执行时,可指定相应参数的值。
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
是否开启同步并行。
flags.DEFINE_boolean("sync_replicas", True,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
在多少个batch后更新模型的参数(在同步更新中)。
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
ps服务器、worker服务器地址的设置信息。
flags.DEFINE_string("ps_hosts","10.10.19.7:2222",
"Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
"Comma-separated list of hostname:port pairs")
job_name、task_index的定义,通常是通过命令行指定,不需要手动填写。
flags.DEFINE_string("job_name", None,"job name: worker or ps")
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
判断是否填写job_name、task_index。
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index =="":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
从变量中解析ps、worker服务器。
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)
2.分布式配置
创建tf中的cluster对象以及server:
cluster = tf.train.ClusterSpec({
"ps": ps_spec,"worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 判断是否为主节点
is_chief = (FLAGS.task_index == 0)
计算资源配置,这里仅使用cpu。如果是ps服务器,则只需要等待worker服务器工作即可。
if FLAGS.job_name == "ps":
server.join()
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
资源配置
with tf.device(tf.train.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/cpu:0",
cluster=cluster)):
3.训练准备
全局步数记录
global_step = tf.Variable(