tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)
Purpose: applies fn to each tensor in the list obtained by unpacking elems along dimension 0, then stacks the results.
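To make the unpack, apply, stack behaviour concrete, here is a small plain-Python model of what map_fn computes. It is a sketch under simplifying assumptions, not TensorFlow's implementation: map_fn_reference is a hypothetical name, Python lists stand in for tensors, a tuple stands in for a nested structure of tensors, and only one level of nesting is handled.

```python
def map_fn_reference(fn, elems):
    """Plain-Python model of map_fn's semantics (hypothetical helper).

    Assumption: a list plays the role of one tensor; a tuple of
    equal-length lists plays the role of a nested structure of tensors.
    """
    if isinstance(elems, tuple):
        n = len(elems[0])
        # Every component must have the same size along dimension 0.
        assert all(len(e) == n for e in elems), "mismatched first dimensions"
        # Slice i is a tuple holding element i of every component,
        # so fn sees the same structure as elems.
        slices = [tuple(e[i] for e in elems) for i in range(n)]
    else:
        # Unpack a single "tensor" along its first dimension.
        slices = list(elems)

    outputs = [fn(s) for s in slices]

    # If fn returns a tuple, stack each component separately so the
    # result has the same structure as fn's output.
    if outputs and isinstance(outputs[0], tuple):
        return tuple(list(col) for col in zip(*outputs))
    return outputs
```

Applied to the first official example, map_fn_reference(lambda x: x * x, [1, 2, 3, 4, 5, 6]) gives [1, 4, 9, 16, 25, 36].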
Parameters:
fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as dtype if one is provided; otherwise it must have the same structure as elems.
elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. The nested sequence of the resulting slices will be passed to fn.
dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn.
parallel_iterations: (optional) The number of iterations allowed to run in parallel.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
name: (optional) Name prefix for the returned tensors.
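The dtype rule above boils down to a structural comparison between what fn returns and what elems looks like. The sketch below models that rule in plain Python; structure and dtype_required are hypothetical helpers, tuples stand in for nested structures and lists for tensors, and only the structure is compared (element types are not checked).

```python
def structure(x):
    """Nesting pattern of x (hypothetical helper).

    Assumption: tuples count as nested structure; anything else
    (a list standing in for a tensor, or a scalar) is a leaf.
    """
    if isinstance(x, tuple):
        return tuple(structure(v) for v in x)
    return "leaf"


def dtype_required(elems, fn):
    """True when fn's output structure differs from that of elems,
    i.e. the case in which dtype is not optional (hypothetical helper)."""
    # Take the first slice along dimension 0 to see what fn receives.
    if isinstance(elems, tuple):
        first_slice = tuple(e[0] for e in elems)
    else:
        first_slice = elems[0]
    return structure(fn(first_slice)) != structure(elems)
```

Under this model, the second and third official examples need dtype (a tuple input mapped to a single value, and a single input mapped to a tuple), while the squaring example does not.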
Examples from the official documentation:
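import numpy as np
import tensorflow as tf

# Square each element of a 1-D array.
elems = np.array([1, 2, 3, 4, 5, 6])
squares = tf.map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

# elems is a tuple of two arrays, so fn receives a tuple of scalars;
# fn returns a single tensor, so dtype must be given.
# The arrays are int64 explicitly so they match dtype=tf.int64 on every platform.
elems = (np.array([1, 2, 3], dtype=np.int64), np.array([-1, 1, -1], dtype=np.int64))
alternate = tf.map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

# fn returns a tuple from a single input, so dtype must be a matching tuple.
elems = np.array([1, 2, 3], dtype=np.int64)
alternates = tf.map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]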