@[tf.scan()函数使用介绍]
tf.scan()帮助文档
在python交互式解释器下输入help(tf.scan)查看帮助文档这句话说明了scan的作用。他将张量elems的第一维元素list作函数计算,带入至fn,作为参数,累加作用,直到遍历完elems的所有一维元素。作用类似于functools里面的reduce函数。
Demo
1维数组:
scan会取出nums每一个一维数据,在这个代码中就是首先取出1,然后和initializer=0进行相加,结果为1;然后在取出x=2,与上一步结果相加,结果为3;以此类推,下一步结果为3+3=6,最终的结果为:
[ 1 3 6 10 15 21]
这个例子很easy.
2维数组:
这里需要注意的是,我们在传入Initializer的时候,需要自己判定应该传入什么类型的初始化值。在这里需要传入一个(6,)的向量。为什么?
tf.scan()会选取elems的每一个第一个维度的子元素,那么在这里就是[1,2,3,4,5,6]和[1,2,3,4,5,6]两个。然后累加作用于lambda表达式中,也就是说[1,2,3,4,5,6]+?,由于不允许一个向量+一个数字的形式出现,所以initialize只能传(6,)的向量。
计算过程:
x = [1,2,3,4,5,6],a = [1,1,1,1,1,1],x+a=[2,3,4,5,6,7]
a = [2,3,4,5,6,7],x=[1,2,3,4,5,6],x+a=[3,5,7,9,11,13]
最终输出:
[
[2,3,4,5,6,7],
[3,5,7,9,11,13]
]
3维数组:只给出Demo,自己分析原因
输出为: