深度学习分布式训练实战(二)——TF

本篇博客主要介绍TF的分布式训练,重点从代码层面进行讲解。理论部分可以参考深度学习分布式训练实战(一)

TF的分布式实现方式

TF的分布式有两种实现方式,一种是图内分布式(In-graph replication);一种是图间分布式(Between-graph replication)。图内分布式中,计算图只有一个,需要一个中心节点分配计算任务并更新参数,由于中心节点的存在,中心节点容易成为瓶颈。图间分布式中,计算图有多个,但是不同计算图的相同变量通过tf.train.replica_device_setter函数放到同一个服务器上,这种情况下,各个计算图相互独立(参数只有一份,计算图有多个),并行度更高,适合异步更新,同步更新下相对麻烦,不过TF给了接口tf.train.SyncReplicasOptimizer函数来帮助实现参数的同步更新,所以图间分布式应用相对广泛一些。
关于数据并行,模型并行可以参考深度学习分布式训练实战(一)
大部分情况下,我们使用图间分布式,图内分布式一般只会在模型太大的情况下使用。对于图间分布式,其基于gRPC通信框架,模型参数只有一份,计算图有多份,一个master负责创建主session,多个worker执行计算图任务。模型训练过程中,每个计算图计算出各自梯度,然后对参数进行更新。更新方式有两种:同步更新,异步更新。

分布式TF中,TF需要建立一个集群,然后在集群中建立两个job,一个是ps job,负责参数初始化,参数更新,一个job下面可以有多个task(有多个task,说明有多台机器,或者GPU负责参数初始化,更新)。一个是woker job,负责计算图的运算,计算梯度,一个worker job下面也可以有很多个task(有多个task,说明有多台机器,或者GPU负责运行计算图)。

参数异步更新的分布式训练

参数同步更新基本上和这里写的差不多TensorFlow分布式部署
。只不过为了方便在本机上调试,所以改了一点点。(自己的笔记本没有GPU),介绍下面几个重点的语句:
tf.train.ClusterSpec():创建一个集群对象
tf.train.Server():在这个集群上面创建一个服务器,根据实际情况,可以是参数服务器,也可以是计算服务器
tf.train.Supervisor():创建一个监视器,就是用来监控训练过程的,个人感觉主要就是方便恢复模型训练,其logdir参数为训练日志目录,如果里面有模型,则直接恢复训练。所以如果想重新训练,需要删除这个目录。
sv.managed_session():启动Session,相比于其他启动Session的方式,多了一些功能。可以参考TensorFlow 中三种启动图用法

具体代码如下:

# tensorflow distribute train by asynchronously update 

import tensorflow as tf
import numpy as np

tf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "", "one of ps or worker")
tf.app.flags.DEFINE_integer("task_index", 0, "0, 1, 2...")

FLAGS = tf.app.flags.FLAGS

def main(_):
	ps_hosts = FLAGS.ps_hosts.split(",")
	worker_hosts = FLAGS.worker_hosts.split(",")

	# Create a cluster from the parameter server and worker server
	cluster = tf.train.ClusterSpec({
   "ps":ps_hosts, "worker":worker_hosts})

	# Create and start a server for the local task
	server = tf.train.Server(cluster, job_name = FLAGS.job_name, task_index=FLAGS.task_index)
    # 如果是参数服务器,则直接阻塞,等待计算服务器下达参数初始化,参数更新命令就可以了。
    # 不过“下达命令”这个是TF内部实现的,没有显式实现
	if FLAGS.job_name == "ps":
		server.join() 
	
  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值