tensorpack之predict解析

在ga3c中,每个agent在获得当前的state(例如游戏图像)后需要向master请求下一步的动作/V值。一个naive的master实现可以这样(简单起见不考虑epsode_finish):

while True:

    identity, image = input_queue.recv() //agent的标识,和相应帧

    action = sess.run(action, feed_dict = {state:image})

    output_queue.send_multipart((identity, action))


然而,这种实现方式有很大的效率问题:每次向gpu传递的数据过少,无法充分利用资源。


可以用以下方式略微改进:

BATCH_SIZE = 128

count = 0

buff = []

while True:

    id, image = input_queue.recv()

    buff.append(image)

    id_buff.append(id)

    count += 1

    if count < BATCH_SIZE:

        continue;


    actions = sess.run(action, feed_dict = {state:image})

    for i in len(actions):

        output_queue.send_multipart((id_buff[i], action[i]))


累积一个batch的图像,一并送入GPU。这样的效率

比第一段代码提高了一倍,但对GPU的利用仍然只有50%。


tensorpack中,作者使用了predictor解决这个问题。predictor只干一件简单的事:从master输入的队列中提取一个batch的数据,调用GPU计算,然后返回结果。由于predictor是多线程的,效率比上面的代码要好。

从train.atari.py开始看起。在MySimulatorMaster,作者如下定义了predictor:

self.async_predictor = MultiThreadAsyncPredictor(
    self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
                                PREDICTOR_THREAD), batch_size=15)

MultiThreadAsyncPredictor是一个控制所有predictor的类。其构造函数如下:

class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """
    An multithread online async predictor which runs a list of OnlinePredictor.
    It would do an extra batching internally.
    """

    def __init__(self, predictors, batch_size=5):
        """
        Args:
            predictors (list): a list of OnlinePredictor avaiable to use.
            batch_size (int): the maximum of an internal batch.
        """
        assert len(predictors)
        self._need_default_sess = False
        for k in predictors:
            assert isinstance(k, OnlinePredictor), type(k)
            if k.sess is None:
                self._need_default_sess = True
            # TODO support predictors.return_input here
            assert not k.return_input
        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, f, id, batch_size=batch_size)
            for id, f in enumerate(predictors)]

        if six.PY2:
            # TODO XXX set logging here to avoid affecting TF logging
            import tornado.options as options
            options.parse_command_line(['--logging=debug'])

拥有self.input_queue和self.threads两个关键的成员变量。

为了观察async_predictor的作用,我们回到MySimulatorMaster类:

def _on_state(self, state, ident):
    def cb(outputs):
        distrib, value = outputs.result()
        assert np.all(np.isfinite(distrib)), distrib
        action = np.random.choice(len(distrib), p=distrib)
        client = self.clients[ident]
        client.memory.append(TransitionExperience(state, action, None, value=value))
        self.send_queue.put([ident, dumps(action)])
    self.async_predictor.put_task([state], cb)

state为某个agent的游戏状态,ident为这个agent的标志。此处定义了一个回调函数,一旦output被计算出来,直接将action传送给send_queue发回给指定agent进程,并把TransitionExperience保存在Master的Memory中,以便进行训练。put_task(concurrency.py)定义了一个Future对象,并为之添加一个callback,最后在input_queue中append一个二元组:(datapoint,callback)。

def put_task(self, dp, callback=None):
    """
    Same as in :meth:`AsyncPredictorBase.put_task`.
    """
    f = Future()
    if callback is not None:
        f.add_done_callback(callback)
    self.input_queue.put((dp, f))
    return f


OK,我们已经理顺了predictor的上下文关系:它从Master的input_queue中获得state,然后把output通过zmq传回client。需要注意的是,它并不是独立进程,而是一个线程。接下来,进一步详细分析predictor的代码。回到第一段代码:

self.async_predictor = MultiThreadAsyncPredictor(
    self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
                                PREDICTOR_THREAD), batch_size=15)

MultiThreadAsyncPredictor接受一个predictor的list作为构造参数。get_predictors函数调用工厂类predictor_factory构造一个OnlinePredictor (见predict/base)。它的基类PredictorBase定义了__call__内建函数:

def __call__(self, *args):
    """
    Call the predictor on some inputs.

    If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
    otherwise, assume ``args`` is a datapoinnt

    Examples:
        When you have a predictor which takes a datapoint [e1, e2], you
        can call it in two ways:

        .. code-block:: python

            predictor(e1, e2)
            predictor([e1, e2])
    """
    if len(args) != 1:
        dp = args
    else:
        dp = args[0]
    output = self._do_call(dp)
    if self.return_input:
        return (dp, output)
    else:
        return output
它直接将dp(datapoint)传递给子类的_do_call函数。因而,可以将OnlinePredictor理解为一个函数。这样,我们进一步观察MultiThreadAsyncPredictor构造函数,其成员变量async_predictor为PredictorWorkerThread的列表。PredictorWorkerThread才是真正执行predict工作的线程。下面列出它的全部代码。


class PredictorWorkerThread(StoppableThread, ShareSessionThread):
    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
        self.name = "PredictorWorkerThread-{}".format(id)
        self.queue = queue
        self.func = pred_func
        self.daemon = True
        self.batch_size = batch_size
        self.id = id

    def run(self):
        with self.default_sess():
            while not self.stopped():
                batched, futures = self.fetch_batch()
                try:
                    outputs = self.func(batched)
                except tf.errors.CancelledError:
                    for f in futures:
                        f.cancel()
                    logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(self.id))
                    return
                # print "Worker {} batched {} Queue {}".format(
                #         self.id, len(futures), self.queue.qsize())
                #  debug, for speed testing
                # if not hasattr(self, 'xxx'):
                #     self.xxx = outputs = self.func(batched)
                # else:
                #     outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]

                for idx, f in enumerate(futures):
                    f.set_result([k[idx] for k in outputs])

    def fetch_batch(self):
        """ Fetch a batch of data without waiting"""
        inp, f = self.queue.get()
        nr_input_var = len(inp)
        batched, futures = [[] for _ in range(nr_input_var)], []
        for k in range(nr_input_var):
            batched[k].append(inp[k])
        futures.append(f)
        while len(futures) < self.batch_size:
            try:
                inp, f = self.queue.get_nowait()
                for k in range(nr_input_var):
                    batched[k].append(inp[k])
                futures.append(f)
            except queue.Empty:
                break   # do not wait
        return batched, futures

构造函数中,queue即为MySimulatorMaster中定义的输入队列,pred_function就是OnlinePredictor (前面说过,这个类可以被当作函数进行调用)。fatch_batch是构造一个batch的数据,作为pred_function的输入,还是比较容易看懂的。

最后是核心的run()。首先调用OnlinePredictor对每个batched进行预测,然后对每个futures列表中的元素(tornado的Future对象)进行set_result操作。至此,程序返回到MysimulatorMaster定义的回调函数,整个predict的流程结束。


通过以上分析,我们可以看出predictor并非独立的进程,不知换成多进程效率将如何变化。虽然代码看似复杂,但整个工作不难理解。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值