tensorflow系列(3)分布式tensorflow

多机如何分布式运行tensorflow模型?

(原文发表在我的博客,欢迎访问

0x00.前言

对于比较复杂的模型,在本机或者单服务器上跑起来需要很长时间。在很多科研单位或公司,可能没有插满gpu的服务器,这时候怎么办呢,有没有可能多台服务器一起跑一个模型呢?

这里就要用到分布式的tensorflow(distributed tensorflow)。

下面介绍在集群上部署tensorflow的方法。

0x01.基本概念

在分布式tensorflow中,服务器被分为两类,一类叫做参数服务器(parameter server,简称ps),另一类叫做计算服务器(worker)。顾名思义,ps会存储参数,分发参数;而worker运行模型,与ps就参数进行交互。

1.训练方式

tensorflow中常用的并行化训练方式有同步模式和异步模式两种方式。

在同步模式中,worker同时读取参数,但是训练完成后不会单独对参数进行更新,而是等待所有worker运行完,统一更新参数。

而在异步训练中,不同worker会对参数独立的更新。

0x02.tensorflow官方示例

tensorflow的官方代码在https://github.com/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py,下面我给示例代码打了一些注释,有条件的朋友可以尝试跑一下

1.变量设置

首先设置tf.app.flags定义标记,在命令行执行时,可指定相应参数的值。

import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS

是否开启同步并行。

flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")

在多少个batch后更新模型的参数(在同步更新中)。

flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")

ps服务器、worker服务器地址的设置信息。

flags.DEFINE_string("ps_hosts","10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")

job_name、task_index的定义,通常是通过命令行指定,不需要手动填写。

flags.DEFINE_string("job_name", None,"job name: worker or ps")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")

判断是否填写job_name、task_index。

if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index =="":
    raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)

从变量中解析ps、worker服务器。

ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)

2.分布式配置

创建tf中的cluster对象以及server:

cluster = tf.train.ClusterSpec({
  "ps": ps_spec,"worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

# 判断是否为主节点
is_chief = (FLAGS.task_index == 0)

计算资源配置,这里仅使用cpu。如果是ps服务器,则只需要等待worker服务器工作即可。

if FLAGS.job_name == "ps":
    server.join()
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)

资源配置

with tf.device(tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):

3.训练准备

全局步数记录

global_step = tf.Variable(
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值