tensorflow control flow 2---the implementation of control flow

    tensorflow control flow 2---the implementation of control flow

目录

    tensorflow control flow 2---the implementation of control flow

   Control-Flow Primitives

   Compilation of Control-Flow Constructs

The cond Operator

The while_loop Operator

实现细节

excutor

分布式执行

反向传递求导


        在使用tensor flow的过程中,我发现TF 有类似于 编程语言三元操作符  a=predate?b:c 的tf.cond ,类似与 while 的 tf.while_loop,以及其他一些控制流操作。对这个不了解的可以看我的上一篇博客。

        tensorflow和spark 有些类似。spark 的rdd 之间存在着依赖(每个rdd 有lineage信息),形成一个以数据为节点,以计算为边的DAG图,解决数据量大而无法单机处理的问题;tensorflow的op之间存在着依赖(通过inputs/outputs),形成一个以计算为节点,以数据流为边的DAG图,解决计算量大而无法单机处理的问题。把一个DAG图分解成计算子图是比较好理解的。但是Tensorflow引入控制流操作之后,计算图的某些节点不需要计算,同时计算图可能存在循环,这已经不是简单的DAG计算图了如何判断计算哪个分支?如何判断循环的结束?如何把循环分割成计算子图?tf是如何实现控制流的呢?这不禁引起了我的好奇。

        为了满足这份好奇心,我花时间看了一些资料以及部分源代码。随着深入的了解,我发现control flow 就是计算图或者说tensorflow实现的数据流机的核心。我准备从理论和源码两个角度来讲control flow ,这是第一部分,讲原理,主要是来源于四篇文章,1TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems,2TensorFlow: A system for large-scale machine learning,3Implementation of Control Flow in TensorFlow和4Dynamic Control Flow in Large-Scale Machine Learning,博客里的表述很多都是翻译的原文,也包含了一些我阅读源码后的理解,我姑且把这篇归为翻译。起初开始看tensorflow源代码,走了不少弯路,不知从何看起。看完这些论文再去研究源代码,才有些登堂入室的感觉。对tensorflow 计算图感兴趣的同学,可以研究一下。

tensorflow 核心的概念是计算图,而计算图的骨架则是控制流。控制流决定了计算图执行的流程。如果没有控制流,tf计算图就只是简单的DAG图,控制流赋予了tf 计算图比DAG 图(以spark 为代表)更强大的表达能力。

   Control-Flow Primitives

    TF 控制流的基本设计原则是,引入一个比较小的与TF dataflow 兼容的原子操作集,原子操作集能够用来表达一系列控制流程,而且支持并行、分布式和自动求导。

 

        在tensorflow,一个计算节点在执行帧(execution frame,类比进程的栈帧)里执行。控制流原语负责创建和管理执行。直观地理解,TF运行时建立一个个执行帧,在执行帧里执行所有属于这个执行帧的计算节点。执行帧可以嵌套(父子关系)。来自不同执行帧且没有依赖关系的计算节点可以并行计算。(这里面的细节很精妙,详细的会在第二篇博客结合源码讲述)

Switch : 取决于输入p的值,Switch 算子把 d 的值传给两个输出中的某一个 。两个输入都可用,Switch节点才可执行

Merge : A Merge 算子把一个可用的输入传给输出。只要任意一个输入可用,Merge便可执行。

Enter(name) : Enter算子把输入传进被name唯一标识的执行帧 。 Enter 算子用来 把一个tensor从一个执行帧 传给一个子执行帧。对一个子执行帧 可能有多个Enter算子, 每个Enter 算子传入一个tensor。输入可用时,Enter可被执行。对name执行的第一次Enter时,name唯一标识的执行帧被构建.

Exit : Enter 算子用来 把一个tensor从一个执行帧 传给父执行帧。 对一个父执行帧 可能有多个Exit算子, 每个Exit算子传入一个tensor。输入可用时,Enter可被执行。

NextIteration: NextIteration 算子 把输入d 传入当前执行帧的下一个 iteration . TF 运行时维护迭代状态。 执行帧中执行的算子绑定一个iteration id, 用来标识op的一次执行(比如在whileloop里同一个op可能被多次执行)。输入可用时,Enter可被执行。

        
 

   Compilation of Control-Flow Constructs

 

 在引入这五个控制原语之后,例如 cond 和 while_loop这样的高层编程语言结构(high-level programming constructs)可以编译成能被tf运行时执行的数据流图。

 

The cond Operator

下面是用算子表示的构建数据流图 cond(pred, fn1, fn2)的伪代码,为了简单起见,忽略了很多细节。具体的实现在control_flow_ops.py

# Build the graph for the true branch

#pred,fn1,fn2 are lists of tensors

context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
# Build the graph for the false branch
context_f = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)]

return merges

 对于条件的每个分支,一个新的控制流上下文被创建,调用条件上下文的图生成函数。条件上下文使得我们可以捕获外部的tensors 并插入合适的Switch 运算符,以根据条件选择相应的分支。这就保证了当一个条件分支被选用,属于这个分支的op才被执行。由于tensorflow的异步执行模型,各个外部tensors 可能在不同的时间成为可用,每个tensor使用一个Switch运算符可以最大化计算的并行度。

 每个条件分支返回一个tensor的list (ref_t or res_f);然后对于每一对tensor对,添加一个Merge 算子。再次,由于各个输出可能在不同的时间被计算出来,对每一个输出使用一个Merge算子,可以尽快地激活下游的计算。

        让我们来看一条简单的代码片段:

 

如上图生成的数据流图,x,y,z各自Switch 控制输入。由于Switch和Merge的存在,只有当 x<y 时,Add 运算符才被执行;反之,只有当x>y为false时,Square才被执行。取决于条件x<y,最终Merge输出Add的结果或 Square的结果。如上一段提到的那样,如果有多个输出,那么会有多个Merge,每个Merge对应一个输出。

 

The while_loop Operator

下面是用算子表示的构建数据流图 while_loop(pred, body,loop_vars)的伪代码:

while_context = WhileContext()
while_context.Enter()
# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]
# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x, x]) for x in enter_vars]
# Build the loop pred subgraph.
pred_result = pred(*merge_vars)
# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]
# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])
# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]
# Form the cycles for the loop.
for m, v in zip(merge_vars, next_vars):
m.op._update_input(1, v)
# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()

return exit_vars

这整个while loop 计算图是在一个 while loop 控制流上下文中生成。这里的基本的idea是非常简单的。        

     以下面这段代码为例子。

 

以 循环变量为起点,对每一个循环变量,添加一个Enter op ,接着一个Merge op。然后使用Merge的输出构建出pred 子图用来计算循环终止条件。

        之后,添加一个Switch节点,用Switch节点的输出去为循环体构造子图。循环体的结果需要进入下一轮的的迭代,所以循环体的结果通过一个NextIteration 节点链接回到Merge,作为Merge的第二个输入。这形成了一个环状,这允许我们迭代地对一个节点计算多次。

        Switch 节点的假值输出是整个循环体的输出,所以在Switch 假值输出链接了一个Exit 节点。循环(while  loop)上下文计录了条件判断和循环体里用到的外部变量,为每一个外部变量添加一个Enter算子,嵌套的while loop需要添加嵌套的Enter。

以上是控制流大概的介绍,构建条件语句和循环语句的python api 实现,具体实现可以看tensorfow 源代码python部分,在后面的博客我会解读这一部分代码。

实现细节

excutor

计算图的构建是在client 端完成的(前向传递和反向传递),计算图构建完成后,通过session(direct session or distributed session)调用tensorflow runtime,tensorflow 运行时负责执行定义好的计算图。

在tensorflow 运行时看来,计算图是由一系列的whileloop的嵌套,一个whileloop大概是这样的结构。这里假设有两个循环变量,没有循环常量。

tensorflow frame

tf.whileloop(pred,body,[a,b])。pred是接受两个tensor,返回一个标量bool值tensor的函数;body是接受两个tensor,返回两个tensor的函数。

为了便于讲述,以下使用只有一个循环变量的图。

为了在多个设备运行计算图(本地模式是多cpu,分布式情况是多个服务器)tensorflow 自动地把这些ops Node 分配到个设备,并插入一些send/receive 节点对

以及一些协调节点,

这其实就是计算图的子图分割的过程,具体的实现在源码tensorflow/core/graph/graph_partion.cc中的partition函数,会在session run方法里被调用。为什么要加一些协调节点,下文会讲。

一个子图被一个绑定到一个device(cpu/gpu/tpu...etc)的excutor管理和执行。excutor从源节点开始,反复地执行可以执行的节点(上文有讲到,除Merge节点外一个节点可以被执行,如果它的输入全部被上游节点计算好了),直到可执行的队列为空。

如果没有control flow,计算图的执行是非常简单的:每个节点只被计算一次,这就是一个简单的DAG图,按照拓扑顺序(拓扑排序会吧?)一个一个执行节点就可以了。但是control flow引入了额外的复杂性,一个节点可以不被执行,也可以被执行任意次。excutor需要管理一个节点的多次运行,以及判断计算图的计算是否完成。

三元组(Node,frame,iteration Id)标识一个执行中的node,被称作taggedNode,tagedNode保存此次执行需要的输入值和计算输出值,这有点像程序和进程的关系。为了标识执行过程中的tensor,excutor 中的tensor由三元组(value,is_dead,tag)表示,其中value是真实的tensor,is_dead 标识当前是否是在一个不被执行的条件分支,tag 是一个string,唯一标识一个tensor(表示是某个tagNode的第几个输出)。tag 是send/recv pair 的传输key的一部分,以区分一对send/recv 的多个执行。

现在我们大概总结一下,excutor的执行过程。excutor 维持系列叫做frame的数据结构,每个frame维持一系列的iteration state 数据结果。frame 和 iteration status在执行的过程中动态地构建和销毁(new/del)。为每个tagednode维持一个peningcount, 表示还有几个输入值没有被计算出来。excutory 维持一个tagged node 工作队列,初始运行的时候,把一些source节点放到队列里,通过线程池,不断地从队列取出节点,计算输出,把输出传给需要这个输出作为输入的节点,改变该节点的pending count 值,如果为0,则可以把这个节点放到工作队列,直到工作队列为空。当然,遇到了control flow ops,需要特殊处理,比如说遇到Enter(name),如果是第一次遇到对应name的Enter节点,则需要初始化一个名字为name的子frame;如果是当前iteration i 中第一次遇到NextIteration,且传入的tensor is_dead 标识不为真,则初始化 iteration i+1,如果为真,不初始化下一个iteration,不往下传 tensor(nextiteration 节点截断whileloop循环)。

以下是计算节点的运算法则。

  • Switch(p, d) = (r​1​, r​2​) :

r​1​ = (value(d), p || is_dead(d), tag(d)) r​2​ = (value(d), !p || is_dead(d), tag(d)) 

 

  •  Merge(d​1​, d​2​) = r :

r = if is_dead(d​1​) then d​2​ else d​1 

  •   Enter(d, frame_name) = r :

value(r) = value(d) is_dead(r) = is_dead(d) tag(r) = tag(d)/frame_name/0 
 

  • Exit(d) = r :

value(r) = value(d) is_dead(r) = is_dead(d) tag(r) = tag​1​ where tag(d) = tag​1​/frame_name/n 

  • NextIteration(d) = d​1​ :

value(d​1​) = value(d) is_dead(d​1​) = is_dead(d) tag(d​1​) = tag​1​/frame_name/(n+1) where tag(d) = tag​1​/frame_name/n 
 

  • Op(d​1​, …, d​m​) = (r​1​, …, r​n​) :

value(r​i​) = Op.Compute(value(d​1​), …, value(d​m​)) if !is_dead(r​i​) is_dead(r​i​) = any(is_dead(d​1​), … is_dead(d​m​)), for all i tag(r​i​) = tag(d​1​), for all i 
 

分布式执行

Distributed Conditional Execution

一个cond 有几个原子控制流算子和其他算子构成,所以分布式执行的时候,可能会被分到不同的devices,如下图所示。

因为recv是一种source节点,所以无条件被执行。即使send节点在一个untaken branch,Recv 也会被执行。为了让Recv 知道这是一个untaken branch,Send节点会把is_dead 标志传给Recv,Recv会把这个节点传下去,直到某个Merge或Nextiteration节点。

Distributed While Loop 

同理,同属于一个whileloop的算子 也可能被分配到不同的devices;

在上面的这个栗子中,循环体中的Op 节点被分配到deviceB。如上图简单的分割子图,无法让Op知道它是属于一个whileloop,只它计算一次就结束了(Recv 触发Op,Recv 只执行一次,Op也就只被执行一次)。解决方案是重写计算图,在每一个子图加入control-loop 状态机。

虚线是控制边

一个标量tenser 0 作为控制循环的Enter输入。这些控制流循环提供了分布式执行whileloop的必要的信息。

让我们来模拟一下whileloop 执行0次的情况:

  • 在device A,Enter、Merge、P、Switch依次被执行。因为P不为真,Merge 会把is_dead 标志传给 send,send 传给device B 的rec节点。Exit 节点同时也可以运行,使得外层依赖这个exit 的节点可以被同时运行。p 的send 节点把 p 的值传给device B。
  • 在device B,Enter 触发 controlloop开始循环,依次执行Enter 和Merge。因为两个Rec 节点依赖Merge,Merge会触发这两节点的执行。连接switch 的 Recv 节点收到 p 的值为false,Next 节点收到is_dead 标识,终止这个循环。连接Op的Rec会收到 一个dead tensor,所以send会传一个dead tensor 回device A。在这个时间点device B 的这个子图当前f没有需要计算的节点,执行完毕。
  • 回到device A,连接Next的Recv收到一个dead tensor,循环也将终止。device A 没有需要执行的节点,计算结束。

 

反向传递求导

tensorflow 支持根据链式法则自动求导。用户可以使用tensorflow的算子构建一个神经网络,定义一个损失函数,tensorflow可以自动求导,构建反向传递子图。

所谓反向传递,就是求导的链式法则。tensorflow 自动求导算法反向便利原始的计算图中的ops,依此调用ops注册的梯度函数,一步步地构建反向传播图。算子的梯度函数构造计算该算子符号梯度的计算子图。简单地说,tensorflow自动求导算法就是根据链式法则把这些个子图组装在一起。反向传播时计算梯度函数可能需要用到前向传播原始算子的输入和输出,所以这些数值需要被保持下来,以供反向传播时使用。

接下来讲有控制流的反向传播。

直观地讲, 控制流算子的反向传递按照下面的规则: Exit 的梯度是Enter; switch 的梯度是Merge (对于cond而言) 或 NextIteration 接 Merge (对while_loop而言);  Merge的梯度是 Switch; NextIteration 的梯度是 Identity; and the gradient of Enter is Exit.  TensorFlow 支持对嵌套的条件结构和循环结构反向求导。

Backpropagation of Conditional 

直观的, cond(p, fn1, fn2)​ 的梯度就是cond(p, g_fn1, g_fn2)​ ,其中g_fn1​ 和 ​g_fn2​ 是fn1和fn2各自的导数。下图是一个cond 的梯度,我们暂时不考虑有whileloop 存在的情况,那会复杂一些。我们假设Op在cond的真值分支(也就是fn1)。

前向传播的Merge的后向传播梯度是 Switch,后向传播Switch和前向传播Switch 使用相等的bool值。梯度 g​y被传向Switch的两个分支。如果forward Switch的两个分支只有一个被用到了,反向传播的Merge会加入0 输入,确保输出一个活的梯度(is_dead等于false) 。0输入由一个swith控制,确保只有当p=false时,0才被输入。

Backpropagation of While Loop

直观地说,对于循环变量,While Loop梯度是如下形式的。

def pred(i, _): return i < N

while_loop(pred, g_body, [0] + g_vars) 。

其中,N是前向传播whileloop 迭代的次数,g_body是前向whileloop循环体的梯度函数。下图是大致的前向传播和对应的反向传播图。

 

对于循环常量。循环常量是指在whileloop中使用到,又没有包含在tf.whilleloop的第三个参数中的tensor,tensorflow 假定在迭代的过程中这些值不会被改变。如下面的代码片段。

import tensorflow as tf 
i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
n = tf.Variable(10)
b=tf.Variable(1)
def cond(a):
    return  a< n
def body(a):
    a = a + b
    return a

a= tf.while_loop(cond, body, [i])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(a.eval())

 

其中 b和n是循环常量,i 是循环变量,简单的说,只有enter 与之对应的是循环常量,既有enter也有exit 、nextiteration与之对应的是循环变量。

对于循环常量,梯度是如下形式的循环体,

def pred(i, _): return i < N
acc = 0.0;
while (_pivot) {
  acc += grad;
}

由于和循环变量的梯度使用相同的pred,循环常量和循环变量的梯度整合在一个whileloop里。

到目前为止,我们的描述是非常简化的。比如说,N在构建计算图的时候是没法静态地确定的。更重要的是, G(Body) 可能会用到前向传播过程中计算出的值,值的”生产“和”使用“是一种LIFO的模式。

为了计算出N,需要在前向传播的whileloop中加入计算N的逻辑。

 

为了保存反向传播中需要用到的前向传播中计算出的值。在构建op时,需要检测op的输入值。如果当前需要构建的op是在一个反向传播的一个whileloop中,而某个输入值是来自前向传播的whileloop,则需要为这个输入值引入一个stack,并加入相应的stackpush/stackpop 算子,在前向传播whileloop 的每一个iteration把这个符号在这一轮的值入栈,在反向传播中按照相反的顺序出 栈。这个栈存在于前向和反向whileloop之外的Frame(所以要有2个Enter,如下图)

符号:在tensorflow 运行时执行计算图之前,tensor只是一个符号。

此外,因为iteration 可以并行,为了保证入栈的次序和出栈的次序。还需要加入一些控制依赖边。保证一个whileloop中,对同一个符号对应的变量的入栈或出栈操作,iteration i中的操作在 iteraion i+1 中的操作之前被执行。以及在whileloop嵌套的情况下,

同一个符号对应的变量的入栈或出栈操作,外层frame iteration i 里的frame里的操作在外层frame iteration i+1 里的framel里的操作之前。

Frame:frame 是一个whileloop的执行时体现,因为whileloop可能嵌套,一个嵌套在whileloop 里的whileloop 可能被执行多次,每次执行对应一个frame,类似于程序和进程的概念。

以上是个对控制流实现大概的梳理。控制流相关的代码,client 端主要在tensorfow/python/ops/control_flow_ops.py。目前C/C++ client API 并不支持对whileloop 求梯度。运行时的对控制流的处理主要是在 1、子图分割时,加入控制流循环,这部分实现在graph 的tensorflow/core/graph/graph_partion.cc ,partition函数会调用AddControFlow函数。2、对控制流原元语按照上文介绍的规则特殊处理,这部分代码实现是在excuter。感兴趣的同学可以看看源码,相信会有不小的收获。

 

 

 

 

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值