map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None)
map_fn 是 tf 中的一个高级函数,其中 fn 是一个可调用函数, elems 是需要处理的 tensors, 可以是一个也可以是一个 tuple,tf 会在每个 Tensor 的第一维度进行展开,然后执行 map 操作,也就是对展开后的每个元素执行 fn 函数。dtype 是 fn 函数的输出类型,如果 fn 函数返回的类型和 elems 的类型不同,则必须提供 dtype 参数。返回值是一个或多个 tensors,每个 Tensor 的第 i 维是每个 fn() 返回的第 i 维拼起来的,这么说有点抽象,请看下面的例子。
import tensorflow as tf
def fun(x):
a, b = x
print("shape in fun(): ", a.shape, b.shape, a.dtype, b.dtype)
# shape (3,1)
return [1, 2, [1,2,3,4,5], [6,7,8]]
a = np.arange(12).reshape((4,3,1))
b = np.arange(1, 1+12).reshape(4,3,1)
print(a.shape, b.shape)
out_dtype = [ tf.int32, tf.int32, [tf.int32]*5, [tf.int32]*3 ]
result = tf.map_fn(fun, elems=(a, b), dtype=out_dtype, parallel_iterations=a.shape[0])
with tf.Session() as sess:
res = sess.run(result)
for e in res:
print(len(e), e)
print("*****************")
fun 函数返回一个 list, [1, 2, [1,2,3,4,5], [6,7,8]]
,a 和 b 的 shape 都是 (4,3,1)。调用 map_fn 时传入 elems=(a,b)
,
实际执行时,tf 会将 a 和 b 都从第一维度展开,也就是分成了 4 个任务,每个任务的参数是 (a’, b’), 这时 a’ 和 b’的 shape 都是 (3,1),然后将每个 (a’,b’) 应用到 fun() 函数中,那么就会有 4 组返回值,每个任务都返回 [1, 2, [1,2,3,4,5], [6,7,8]]
,但其实 map_fn() 最终的格式并不是简单的把 4 组返回拼接成一个 list,如 [[1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]]]
这样,而是把每组返回值的第 i 维合并平成 map_fn() 的第 i 维,如 map_fn()[0] = [fun()[0], fun()[0], ..., fun()[0]]
。
这个例子中 map_fn() 返回值格式是这样的:
[[1, 1, 1, 1],
[2, 2, 2, 2],
[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5]],
[[6, 6, 6, 6],
[7, 7, 7, 7],
[8, 8, 8, 8]]
]
完整的代码运行结果如下:
(4, 3, 1) (4, 3, 1)
shape in fun(): (3, 1) (3, 1) <dtype: 'int64'> <dtype: 'int64'>
4 [1 1 1 1]
*****************
4 [2 2 2 2]
*****************
5 [array([1, 1, 1, 1], dtype=int32), array([2, 2, 2, 2], dtype=int32), array([3, 3, 3, 3], dtype=int32), array([4, 4, 4, 4], dtype=int32), array([5, 5, 5, 5], dtype=int32)]
*****************
3 [array([6, 6, 6, 6], dtype=int32), array([7, 7, 7, 7], dtype=int32), array([8, 8, 8, 8], dtype=int32)]
*****************
如果把 dtype 参数设置成 None,就会报如下错误:
ValueError: The two structures don't have the same number of elements. First structure: (tf.int64, tf.int64), second structure: [1, 2, [1, 2, 3, 4, 5], [6, 7, 8]].
因为 fun() 的参数是 x, x 是一个 ‘tuple(a, b)’,x 的 dtype 是 (tf.int64, tf.int64)
, 而 fun() 函数返回结构是 [1, 2, [1, 2, 3, 4, 5], [6, 7, 8]]
,所以必须按照这个结构提供返回的 dtype 格式。