Implementing Semi-Synchronous Training in TensorFlow

1. Semi-synchronous training example:

#coding=utf-8

# Fully synchronous mode: run with --issync=1; semi-synchronous (SSP) mode: --issync=0
# python sync_dist_train.py --job_name=ps --task_index=0 --issync=0
# python sync_dist_train.py --job_name=worker --task_index=0 --issync=0
# python sync_dist_train.py --job_name=worker --task_index=1 --issync=0
# python sync_dist_train.py --job_name=worker --task_index=2 --issync=0

import time
import numpy as np
import tensorflow as tf

# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
                     'Steps to validate and print loss')

# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "127.0.0.1:2222",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "127.0.0.1:2224,127.0.0.1:2225,127.0.0.1:2226",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("issync", 0, "Distributed mode: 1 = fully synchronous, 0 = semi-synchronous (SSP)")


# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate


def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
  server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

  worker_count = len(worker_hosts)

  issync = FLAGS.issync
  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":
    with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
      global_step = tf.Variable(0, name='global_step', trainable=False)

      X = tf.placeholder(tf.float32)
      Y = tf.placeholder(tf.float32)
      w = tf.Variable(0.0, name="weight")
      b = tf.Variable(0.0, name="bias")
      y = w * X + b

      loss = tf.reduce_mean(tf.square(y - Y))
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)

      if issync == 1:
        # Fully synchronous mode
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=len(worker_hosts),
                                                   total_num_replicas=len(worker_hosts),
                                                   use_locking=True)
        sync_replicas_hook = optimizer.make_session_run_hook(FLAGS.task_index == 0)

      # Apply the gradient update
      train_op = optimizer.minimize(loss, global_step=global_step)

      hooks = [tf.train.StopAtStepHook(last_step=1000000)]
      if issync == 1:
        hooks.append(sync_replicas_hook)
      else:
        # Semi-synchronous (SSP) mode, using the custom hook from section 2
        semi_sync_hook = tf.train.SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10)
        hooks.append(semi_sync_hook)

      with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(FLAGS.task_index == 0),
        checkpoint_dir="./train_logs", hooks=hooks) as mon_sess:
          while not mon_sess.should_stop():
            train_x = np.random.randn(1)
            train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
            _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y})
            if step % steps_to_validate == 0:
              w_, b_ = mon_sess.run([w, b])
              print("step: %d, weight: %f, bias: %f, loss: %f" % (step, w_, b_, loss_v))



if __name__ == "__main__":
  tf.app.run()
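
Note that tf.train.SemiSyncRunHook is not part of stock TensorFlow; it only exists after the implementation from section 2 is patched into the TensorFlow source tree and TensorFlow is rebuilt. If you would rather not rebuild TensorFlow, a minimal sketch of the same idea as a standalone module (hypothetical file name semi_sync_hook.py), built only on the public tf.train.SessionRunHook API, could look like this:

# semi_sync_hook.py -- hypothetical standalone variant of the hook from
# section 2, using only public TensorFlow 1.x APIs.
import time
import tensorflow as tf

class SemiSyncRunHook(tf.train.SessionRunHook):
  def __init__(self, index, worker_count, staleness=10):
    self._index = index
    self._staleness = staleness
    self._wait_time = 0.01
    # One shared step counter per worker; construct the hook inside the
    # replica_device_setter scope so the variables are placed on the PS.
    self._worker_steps = [
        tf.Variable(0, trainable=False, name="worker_step_%d" % i)
        for i in range(worker_count)]
    self._my_step_update_op = tf.assign_add(self._worker_steps[index], 1)
    self._initialize_op = tf.variables_initializer(self._worker_steps)

  def after_create_session(self, session, coord):
    session.run(self._initialize_op)

  def before_run(self, run_context):
    run_context.session.run(self._my_step_update_op)

  def after_run(self, run_context, run_values):
    # Block while this worker is more than `staleness` steps ahead.
    while True:
      steps = run_context.session.run(self._worker_steps)
      if steps[self._index] <= min(steps) + self._staleness:
        break
      time.sleep(self._wait_time)

With such a module on the Python path, the call in the script above would become semi_sync_hook.SemiSyncRunHook(...) instead of tf.train.SemiSyncRunHook(...).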

2. Semi-synchronous training implementation:

# This class is intended to be added to TensorFlow's source tree
# (hence the @tf_export decorator), where these imports resolve:
import time

from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import session_run_hook
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.SemiSyncRunHook")
class SemiSyncRunHook(session_run_hook.SessionRunHook):
  """SessionRunHook that implements semi-synchronous (SSP) training."""

  def __init__(self, index, worker_count, staleness=10):
    """Initializes a `SemiSyncRunHook`.

    Args:
      index: this worker's task index.
      worker_count: total number of workers.
      staleness: maximum number of steps this worker may run ahead of the
        slowest worker before it is forced to wait.
    """

    if index >= worker_count:
      raise ValueError(
          "worker index {} must be smaller than worker_count {}".format(
              index, worker_count))

    self._const_max_test_step = 10000
    self._last_step = 0  # step count at the last wait-time recalibration
    self._last_time = self._now_time()  # wall-clock time at the last recalibration

    self._index = index
    self._staleness = staleness
    # Sleep interval in seconds; do not make this too long -- it depends on
    # the workers' training speed and on `staleness`.
    self._wait_time = 0.01
    self._worker_steps = []  # one step-counter variable per worker

    for i in range(worker_count):
      worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i))
      self._worker_steps.append(worker_step)
      if i == index:
        self._my_step_update_op = state_ops.assign_add(worker_step, 1)

    self._initialize_op = variables.variables_initializer(self._worker_steps)

  def _now_time(self):  # seconds
    return time.time()

  def after_create_session(self, session, coord):
    session.run(self._initialize_op)  # initialize the per-worker step counters

  def before_run(self, run_context):
    run_context.session.run(self._my_step_update_op)  # bump this worker's step counter
    return None

  def after_run(self, run_context, run_values):
    while True:
      # 1. Fetch the step counters of all workers.
      all_worker_steps = run_context.session.run(self._worker_steps)

      # 2. If this worker is more than `staleness` steps ahead of the slowest
      #    worker, sleep for a while and re-check; otherwise keep training.
      if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness:
        diff_step = all_worker_steps[self._index] - self._last_step
        if diff_step > self._const_max_test_step:
          # Recalibrate the sleep interval to ~70% of the time this worker
          # needs to advance `staleness` steps at its current speed.
          self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7

          # Remember when and where the recalibration happened.
          self._last_step = all_worker_steps[self._index]
          self._last_time = self._now_time()

        time.sleep(self._wait_time)  # wait for the slow workers to catch up
      else:
        break
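
The recalibration above sets the sleep interval to roughly 70% of the time this worker needs to advance staleness steps at its current speed. A quick numeric check, with assumed values:

# Assumed for illustration: the worker ran 10000 steps in 20 s (2 ms/step).
elapsed = 20.0     # seconds since the last recalibration
diff_step = 10000  # steps run in that window
staleness = 10

wait_time = elapsed / diff_step * staleness * 0.7
print(wait_time)   # 0.014 -> sleep ~14 ms per gate check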

3. How semi-synchronous (SSP) training works:

Each worker owns a shared, non-trainable step counter that lives on the parameter servers. before_run increments this worker's counter once per training step, and after_run polls all counters: as long as this worker is more than staleness steps ahead of the slowest worker, it sleeps and re-checks, so fast workers are throttled while slow workers never have to wait. The sleep interval is recalibrated every _const_max_test_step steps from the worker's own measured step time.
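
The gating logic can be demonstrated without TensorFlow at all. The following sketch (assumed speeds, in-process threads standing in for workers, a plain list standing in for the PS-hosted counters) simulates one fast and one slow worker and shows that the step gap stays bounded by the staleness threshold:

import threading
import time

STALENESS = 10
WAIT_TIME = 0.001   # fixed sleep interval; the real hook adapts this

steps = [0, 0]           # shared per-worker step counters
lock = threading.Lock()  # protects `steps`

def worker(index, step_cost, total_steps):
  for _ in range(total_steps):
    time.sleep(step_cost)  # simulate one training step
    with lock:
      steps[index] += 1
    # SSP gate: block while more than STALENESS steps ahead of the slowest.
    while True:
      with lock:
        ahead = steps[index] - min(steps)
      if ahead <= STALENESS:
        break
      time.sleep(WAIT_TIME)

fast = threading.Thread(target=worker, args=(0, 0.001, 200))  # ~1 ms/step
slow = threading.Thread(target=worker, args=(1, 0.003, 200))  # ~3 ms/step
fast.start(); slow.start()
fast.join(); slow.join()
print("final steps:", steps)  # both reach 200; the gap stayed within STALENESS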
