In GA3C, after obtaining the current state (e.g. a game frame), each agent must ask the master for the next action / V-value. A naive master implementation might look like this (for simplicity, episode_finish is ignored):
while True:
    identity, image = input_queue.recv()  # the agent's id, and its current frame
    action = sess.run(action_op, feed_dict={state: image})  # action_op: the policy output tensor
    output_queue.send_multipart((identity, action))
However, this implementation is quite inefficient: each call feeds the GPU too little data to make full use of it.
It can be improved slightly as follows:
BATCH_SIZE = 128
buff, id_buff = [], []
while True:
    identity, image = input_queue.recv()
    buff.append(image)
    id_buff.append(identity)
    if len(buff) < BATCH_SIZE:
        continue
    actions = sess.run(action_op, feed_dict={state: buff})  # feed the whole batch
    for i in range(len(actions)):
        output_queue.send_multipart((id_buff[i], actions[i]))
    buff, id_buff = [], []  # reset for the next batch
Accumulating a batch of images and sending them to the GPU in a single call roughly doubles the throughput of the first snippet, but GPU utilization is still only about 50%: while the master is collecting the next batch the GPU sits idle, and while the GPU is computing, no new requests are being collected.
In tensorpack, the author solves this problem with predictors. A predictor does one simple thing: it takes a batch of data from the master's input queue, runs the computation on the GPU, and returns the results. Since predictors run in multiple threads, this is more efficient than the code above.
Let's start from train-atari.py. In MySimulatorMaster, the author defines the predictor as follows:
self.async_predictor = MultiThreadAsyncPredictor(
    self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
                                PREDICTOR_THREAD),
    batch_size=15)
MultiThreadAsyncPredictor is the class that manages all the predictors. Its constructor is as follows:
class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """ An multithread online async predictor which runs a list of
    OnlinePredictor. It would do an extra batching internally. """

    def __init__(self, predictors, batch_size=5):
        """
        Args:
            predictors (list): a list of OnlinePredictor avaiable to use.
            batch_size (int): the maximum of an internal batch.
        """
        assert len(predictors)
        self._need_default_sess = False
        for k in predictors:
            assert isinstance(k, OnlinePredictor), type(k)
            if k.sess is None:
                self._need_default_sess = True
            # TODO support predictors.return_input here
            assert not k.return_input
        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, f, id, batch_size=batch_size)
            for id, f in enumerate(predictors)]

        if six.PY2:
            # TODO XXX set logging here to avoid affecting TF logging
            import tornado.options as options
            options.parse_command_line(['--logging=debug'])
It holds two key member variables: self.input_queue and self.threads.
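The shape of these two members can be sketched with a toy example. This is not tensorpack code: the names (ToyWorker-style `worker`, the doubling stand-in for `sess.run`) are made up for illustration; the point is one shared queue drained by a pool of daemon threads.

```python
import queue
import threading

input_queue = queue.Queue()   # shared by all workers, like self.input_queue
done = []

def worker():
    while True:
        item = input_queue.get()
        if item is None:          # sentinel: ask the thread to stop
            break
        done.append(item * 2)     # stand-in for running the model

# a small pool of threads, like self.threads
threads = [threading.Thread(target=worker, daemon=True) for _ in range(2)]
for t in threads:
    t.start()

for x in range(4):
    input_queue.put(x)
for _ in threads:
    input_queue.put(None)         # one sentinel per worker
for t in threads:
    t.join()

print(sorted(done))  # [0, 2, 4, 6]
```

Because all workers share one queue, whichever thread is free picks up the next request, which is what lets several GPU calls overlap with request collection.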
To see what async_predictor does, let's return to the MySimulatorMaster class:
def _on_state(self, state, ident):
    def cb(outputs):
        distrib, value = outputs.result()
        assert np.all(np.isfinite(distrib)), distrib
        action = np.random.choice(len(distrib), p=distrib)
        client = self.clients[ident]
        client.memory.append(TransitionExperience(state, action, None, value=value))
        self.send_queue.put([ident, dumps(action)])
    self.async_predictor.put_task([state], cb)
Here state is one agent's game state and ident is that agent's identifier. A callback is defined: once the output has been computed, the action is pushed onto send_queue and sent back to the corresponding agent process, and a TransitionExperience is stored in the master's memory for later training. put_task (concurrency.py) creates a Future object, attaches the callback to it, and finally puts a two-tuple (datapoint, future) onto input_queue:
def put_task(self, dp, callback=None):
    """
    Same as in :meth:`AsyncPredictorBase.put_task`.
    """
    f = Future()
    if callback is not None:
        f.add_done_callback(callback)
    self.input_queue.put((dp, f))
    return f
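The put_task pattern can be reproduced with the standard library's concurrent.futures.Future (tensorpack uses tornado's Future, which schedules callbacks on an IOLoop, but the enqueue-then-set_result flow is the same). Everything here is an illustrative sketch, not tensorpack code:

```python
import queue
from concurrent.futures import Future

input_queue = queue.Queue()
results = []

def put_task(dp, callback=None):
    f = Future()
    if callback is not None:
        f.add_done_callback(callback)   # fires when set_result is called
    input_queue.put((dp, f))            # (datapoint, future), as in tensorpack
    return f

# the caller registers a callback and forgets about the request
put_task([42], lambda fut: results.append(fut.result()))

# a worker thread would normally do this after running the model:
dp, fut = input_queue.get()
fut.set_result(dp[0] * 2)

print(results)  # [84] -- the callback has fired
```

This is exactly why `_on_state` can return immediately: the result is delivered asynchronously through the future's callback.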
OK, we have now sorted out the predictor's context: it obtains states from the master's input_queue, and the outputs are sent back to the clients over zmq. Note that a predictor is not a separate process but a thread. Next, let's examine the predictor's code in more detail, starting again from the first snippet:
self.async_predictor = MultiThreadAsyncPredictor(
    self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
                                PREDICTOR_THREAD),
    batch_size=15)
MultiThreadAsyncPredictor takes a list of predictors as its constructor argument. get_predictors calls the factory class predictor_factory to construct OnlinePredictor instances (see predict/base). Their base class PredictorBase defines the __call__ special method:
def __call__(self, *args):
    """
    Call the predictor on some inputs.

    If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
    otherwise, assume ``args`` is a datapoinnt

    Examples:
        When you have a predictor which takes a datapoint [e1, e2], you
        can call it in two ways:

        .. code-block:: python

            predictor(e1, e2)
            predictor([e1, e2])
    """
    if len(args) != 1:
        dp = args
    else:
        dp = args[0]
    output = self._do_call(dp)
    if self.return_input:
        return (dp, output)
    else:
        return output

It simply passes dp (the datapoint) on to the subclass's _do_call method, so an OnlinePredictor can be treated as a function. Looking back at the MultiThreadAsyncPredictor constructor, its member variable self.threads is a list of PredictorWorkerThread objects. PredictorWorkerThread is the thread that actually performs the prediction; its full code is listed below.
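The "predictor as a function" idea rests on Python's __call__ protocol. A minimal toy class (not the real OnlinePredictor; `_do_call` here is a made-up stand-in for the forward pass) shows both calling conventions from the docstring:

```python
class ToyPredictor:
    def __call__(self, *args):
        # one argument -> it is already a datapoint (a list);
        # otherwise the positional args form the datapoint
        dp = args[0] if len(args) == 1 else list(args)
        return self._do_call(dp)

    def _do_call(self, dp):
        # stand-in for a real model forward pass
        return [x + 1 for x in dp]

p = ToyPredictor()
print(p(1, 2))    # [2, 3]
print(p([1, 2]))  # [2, 3] -- same datapoint, same result
```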
class PredictorWorkerThread(StoppableThread, ShareSessionThread):
    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
        self.name = "PredictorWorkerThread-{}".format(id)
        self.queue = queue
        self.func = pred_func
        self.daemon = True
        self.batch_size = batch_size
        self.id = id

    def run(self):
        with self.default_sess():
            while not self.stopped():
                batched, futures = self.fetch_batch()
                try:
                    outputs = self.func(batched)
                except tf.errors.CancelledError:
                    for f in futures:
                        f.cancel()
                    logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(self.id))
                    return
                for idx, f in enumerate(futures):
                    f.set_result([k[idx] for k in outputs])

    def fetch_batch(self):
        """ Fetch a batch of data without waiting"""
        inp, f = self.queue.get()
        nr_input_var = len(inp)
        batched, futures = [[] for _ in range(nr_input_var)], []
        for k in range(nr_input_var):
            batched[k].append(inp[k])
        futures.append(f)
        while len(futures) < self.batch_size:
            try:
                inp, f = self.queue.get_nowait()
                for k in range(nr_input_var):
                    batched[k].append(inp[k])
                futures.append(f)
            except queue.Empty:
                break  # do not wait
        return batched, futures
In the constructor, queue is the input queue defined in MySimulatorMaster, and pred_func is an OnlinePredictor (which, as noted above, can be called like a function). fetch_batch assembles a batch of data to serve as the input to pred_func, and is fairly easy to follow.
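The essence of fetch_batch, stripped of the multiple-input-variable bookkeeping, is: block for the first item, then drain whatever else is already queued without waiting, up to batch_size. A self-contained sketch (not tensorpack code):

```python
import queue

def fetch_batch(q, batch_size=5):
    items = [q.get()]                     # block until at least one item arrives
    while len(items) < batch_size:
        try:
            items.append(q.get_nowait())  # take what's there, but don't wait
        except queue.Empty:
            break                          # an undersized batch is fine
    return items

q = queue.Queue()
for x in (1, 2, 3):
    q.put(x)
batch = fetch_batch(q, batch_size=5)
print(batch)  # [1, 2, 3] -- smaller than batch_size, returned immediately
```

This "opportunistic batching" is the design choice that fixes the 50%-utilization problem: the worker never stalls waiting for a full batch, yet still batches whenever requests have piled up.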
Finally, the core run() method: it calls the OnlinePredictor on each batch, then calls set_result on every element of the futures list (tornado Future objects). At that point control returns to the callback defined in MySimulatorMaster, and the whole prediction flow is complete.
From the above analysis we can see that a predictor is not a separate process; it would be interesting to know how efficiency would change with a multi-process implementation. Although the code looks complex, the overall workflow is not hard to understand.