TensorFlow map_fn
flyfish
import numpy as np
import tensorflow as tf
# Demonstrates tf.map_fn with three input/output structures, using the
# TensorFlow 1.x graph/session API.
#
# The inputs are created as int64 explicitly: NumPy's default integer width
# is platform-dependent, while the dtype= passed to map_fn below is tf.int64
# and has to match what the mapped function actually returns.

# Case 1: one tensor in, one tensor out. No dtype= is given, so the output
# dtype is taken from elems.
elems = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
squares = tf.map_fn(lambda x: x * x, elems)

# Case 2: a tuple of tensors in, one tensor out. The output structure differs
# from the input structure, so dtype= must be given.
elems = (
    np.array([1, 2, 3], dtype=np.int64),
    np.array([-1, 1, -1], dtype=np.int64),
)
alternate = tf.map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)

# Case 3: one tensor in, a tuple of tensors out. dtype= mirrors the structure
# of the tuple returned by the lambda.
elems = np.array([1, 2, 3], dtype=np.int64)
alternates = tf.map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))

# The context manager closes the session (and releases its resources) when
# the block exits, even if one of the evaluations raises.
with tf.Session() as sess:
    print(sess.run(squares))        # [ 1  4  9 16 25 36]
    print(sess.run(alternate))      # [-1  2 -3]
    print(sess.run(alternates[0]))  # [1 2 3]
    print(sess.run(alternates[1]))  # [-1 -2 -3]