tf.map_fn() 函数使用说明及示例

map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
           swap_memory=False, infer_shape=True, name=None)

map_fn 是 tf 中的一个高级函数,其中 fn 是一个可调用函数, elems 是需要处理的 tensors, 可以是一个也可以是一个 tuple,tf 会在每个 Tensor 的第一维度进行展开,然后执行 map 操作,也就是对展开后的每个元素执行 fn 函数。dtype 是 fn 函数的输出类型,如果 fn 函数返回的类型和 elems 的类型不同,则必须提供 dtype 参数。返回值是一个或多个 tensors,每个 Tensor 的第 i 维是每个 fn() 返回的第 i 维拼起来的,这么说有点抽象,请看下面的例子。

import tensorflow as tf
def fun(x):
    a, b = x
    print("shape in fun(): ", a.shape, b.shape, a.dtype, b.dtype)
    # shape (3,1)
    return [1, 2, [1,2,3,4,5], [6,7,8]]

a = np.arange(12).reshape((4,3,1))
b = np.arange(1, 1+12).reshape(4,3,1)
print(a.shape, b.shape)

out_dtype = [ tf.int32, tf.int32, [tf.int32]*5, [tf.int32]*3 ]
result = tf.map_fn(fun, elems=(a, b), dtype=out_dtype, parallel_iterations=a.shape[0])
with tf.Session() as sess:
    res = sess.run(result)
    for e in res:
        print(len(e), e)
        print("*****************")

fun 函数返回一个 list, [1, 2, [1,2,3,4,5], [6,7,8]],a 和 b 的 shape 都是 (4,3,1)。调用 map_fn 时传入 elems=(a,b)
实际执行时,tf 会将 a 和 b 都从第一维度展开,也就是分成了 4 个任务,每个任务的参数是 (a’, b’), 这时 a’ 和 b’的 shape 都是 (3,1),然后将每个 (a’,b’) 应用到 fun() 函数中,那么就会有 4 组返回值,每个任务都返回 [1, 2, [1,2,3,4,5], [6,7,8]],但其实 map_fn() 最终的格式并不是简单的把 4 组返回拼接成一个 list,如 [[1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]], [1, 2, [1,2,3,4,5], [6,7,8]]] 这样,而是把每组返回值的第 i 维合并平成 map_fn() 的第 i 维,如 map_fn()[0] = [fun()[0], fun()[0], ..., fun()[0]]

这个例子中 map_fn() 返回值格式是这样的:

[[1, 1, 1, 1],
 [2, 2, 2, 2],
 [[1, 1, 1, 1],
  [2, 2, 2, 2],
  [3, 3, 3, 3],
  [4, 4, 4, 4],
  [5, 5, 5, 5]],
 [[6, 6, 6, 6],
  [7, 7, 7, 7],
  [8, 8, 8, 8]]
]

完整的代码运行结果如下:

(4, 3, 1) (4, 3, 1)
shape in fun():  (3, 1) (3, 1) <dtype: 'int64'> <dtype: 'int64'>
4 [1 1 1 1]
*****************
4 [2 2 2 2]
*****************
5 [array([1, 1, 1, 1], dtype=int32), array([2, 2, 2, 2], dtype=int32), array([3, 3, 3, 3], dtype=int32), array([4, 4, 4, 4], dtype=int32), array([5, 5, 5, 5], dtype=int32)]
*****************
3 [array([6, 6, 6, 6], dtype=int32), array([7, 7, 7, 7], dtype=int32), array([8, 8, 8, 8], dtype=int32)]
*****************

如果把 dtype 参数设置成 None,就会报如下错误:

ValueError: The two structures don't have the same number of elements. First structure: (tf.int64, tf.int64), second structure: [1, 2, [1, 2, 3, 4, 5], [6, 7, 8]].

因为 fun() 的参数是 x, x 是一个 ‘tuple(a, b)’,x 的 dtype 是 (tf.int64, tf.int64), 而 fun() 函数返回结构是 [1, 2, [1, 2, 3, 4, 5], [6, 7, 8]],所以必须按照这个结构提供返回的 dtype 格式。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值