tensorflow control flow 4 ---python client api之tf.cond and tf.whileloop

tf.cond and tf.whileloop

本篇博客梳理一下 tensorflow python client API 与tf.cond和 tf.whileloop相关的部分,对于这两个api 的详细解释,请看我的第一篇博客和第二篇博客,这篇博客讲一些api 的实现细节。

API 梳理

tf.cond

下图作为讲解tf.cond代码时的参考

以下是tf.cond的伪代码。源代码在controlf_flow_ops.py

context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
# Build the graph for the false branch
context_f = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)]

return merges

正如我在第二篇博客提到的那样。

tf.whileloop

下图作为tf.whileloop的参考:

tensorflow frame

以下是tf.whileloop的伪代码。源代码在controlf_flow_ops.py

while_context = WhileContext()
while_context.Enter()
# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]
# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x, x]) for x in enter_vars]
# Build the loop pred subgraph.
pred_result = pred(*merge_vars)
# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]
# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])
# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]
# Form the cycles for the loop.
for m, v in zip(merge_vars, next_vars):
m.op._update_input(1, v)
# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()

return exit_vars

实现细节

由上面的伪代码可知,tf.cond和tf.whileloop实现的时候都是建立一个上下文(context),由此抽象出一个基类ControlFlowContext,来处理cond和whileloop相关的地方。一个控制流上下文主要保存了控制流相关的上下文信息,当在一个上下文里新建一个operation ,operation初始化时会用到这些上下文。比如说,如果在一个whileloop里新建一个operation,而这个whileloop是一个前向传播whileloop的反向传播(详情请参考我的第二篇博客),而且这个operation 需要用到前向传播的迭代中计算出来的值。前先传播中值的生产和反向传播中对值的消费,是一种first in  last out 的关系,所以要加入一个栈和相应的入栈和出栈操作,如下图所示。

此外,从一个控制流上下文输出的值也需要做处理、对控制流上下文里op控制依赖边也要做一些调整。这些都依赖控制上下文信息。

class ControlFlowContext(object)

这个基类抽象出对cond 和whileloop 通用的接口。

以下分别介绍cond 和whileloop 两个类。

由上面的两张图,我们可以知道,cond 和 whileloop 的一个重要的相同之处是,它们都要对在这个上下文中创建的op的输入做一些处理,这些处理由addop这个方法完成。下面我主要讲这个方法

这里面的逻辑是,一个op 属于某一个controflow context, controflow context是嵌套的,当一个op 初始化时,对该op 的输入值,调用op 所在ControlFlowContext的addvalues 方法。对于cond,是以传入的输入值为基础添加一个merge节点;对于whileloop,可能需要加入入栈和出栈的操作。

__init__ (python/framework/ops.py 里的operation初始化方法)

  • self._control_flow_post_processing
    • self._control_flow_context.AddOp(self)

class CondContext(ControlFlowContext)

class WhileContext(ControlFlowContext)

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值