tf.cond and tf.whileloop
本篇博客梳理一下 tensorflow python client API 与tf.cond和 tf.whileloop相关的部分,对于这两个api 的详细解释,请看我的第一篇博客和第二篇博客,这篇博客讲一些api 的实现细节。
API 梳理
tf.cond
下图作为讲解tf.cond代码时的参考

以下是tf.cond的伪代码。源代码在controlf_flow_ops.py
context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
# Build the graph for the false branch
context_f = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)]
return merges
正如我在第二篇博客提到的那样。
tf.whileloop
下图作为tf.whileloop的参考:

以下是tf.whileloop的伪代码。源代码在controlf_flow_ops.py
while_context = WhileContext()
while_context.Enter()
# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]
# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x, x]) for x in enter_vars]
# Build the loop pred subgraph.
pred_result = pred(*merge_vars)
# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]
# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])
# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]
# Form the cycles for the loop.
for m, v in zip(merge_vars, next_vars):
m.op._update_input(1, v)
# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars
实现细节
由上面的伪代码可知,tf.cond和tf.whileloop实现的时候都是建立一个上下文(context),由此抽象出一个基类ControlFlowContext,来处理cond和whileloop相关的地方。一个控制流上下文主要保存了控制流相关的上下文信息,当在一个上下文里新建一个operation ,operation初始化时会用到这些上下文。比如说,如果在一个whileloop里新建一个operation,而这个whileloop是一个前向传播whileloop的反向传播(详情请参考我的第二篇博客),而且这个operation 需要用到前向传播的迭代中计算出来的值。前先传播中值的生产和反向传播中对值的消费,是一种first in last out 的关系,所以要加入一个栈和相应的入栈和出栈操作,如下图所示。

此外,从一个控制流上下文输出的值也需要做处理、对控制流上下文里op控制依赖边也要做一些调整。这些都依赖控制上下文信息。
class ControlFlowContext(object)
这个基类抽象出对cond 和whileloop 通用的接口。
以下分别介绍cond 和whileloop 两个类。
由上面的两张图,我们可以知道,cond 和 whileloop 的一个重要的相同之处是,它们都要对在这个上下文中创建的op的输入做一些处理,这些处理由addop这个方法完成。下面我主要讲这个方法
这里面的逻辑是,一个op 属于某一个controflow context, controflow context是嵌套的,当一个op 初始化时,对该op 的输入值,调用op 所在ControlFlowContext的addvalues 方法。对于cond,是以传入的输入值为基础添加一个merge节点;对于whileloop,可能需要加入入栈和出栈的操作。
__init__ (python/framework/ops.py 里的operation初始化方法)
-
self._control_flow_post_processing-
self._control_flow_context.AddOp(self)
-
class CondContext(ControlFlowContext)
class WhileContext(ControlFlowContext)

1万+

被折叠的 条评论
为什么被折叠?



