def flatten(x):
    """Collapse every axis after the first into a single axis.

    Input:
    - x: TensorFlow Tensor of shape (N, D1, ..., DM)

    Output:
    - TensorFlow Tensor of shape (N, D1 * ... * DM)
    """
    # The leading dimension is read with tf.shape (a Tensor evaluated when the
    # graph runs); the -1 lets tf.reshape infer the size of the merged axis.
    return tf.reshape(x, (tf.shape(x)[0], -1))
def test_flatten():
    """Demonstrate building a graph around flatten() and running it twice.

    The function first constructs a graph (a placeholder fed through
    flatten), prints the symbolic Tensors, and then evaluates the graph in a
    Session for two differently shaped numpy inputs, printing each input and
    its flattened result. Nothing is returned.

    Relies on the names `tf`, `np`, `device` and `flatten` being available
    in the enclosing module scope.
    """
    # Start from an empty default graph so earlier definitions do not linger.
    tf.reset_default_graph()

    # Stage I: describe the computation. The placeholder stands in for an
    # input whose value is only supplied when the graph is run; passing it to
    # flatten yields another Tensor representing the flattened result. The
    # tf.device context controls where these Tensors are placed.
    with tf.device(device):
        inp = tf.placeholder(tf.float32)
        flat = flatten(inp)

    # Nothing has been computed so far: printing the Tensors shows symbolic
    # graph nodes, not data.
    print('x: ', type(inp), inp)
    print('x_flat: ', type(flat), flat)
    print()

    # Stage II: a Session is required to actually execute the graph.
    with tf.Session() as session:
        # First concrete input, built with numpy.
        first_np = np.arange(24).reshape(2, 3, 4)
        print('x_np:\n', first_np, '\n')

        # session.run takes the Tensor to evaluate plus a feed_dict mapping
        # each placeholder to a concrete value, and hands back a numpy array.
        first_feed = {inp: first_np}
        first_flat_np = session.run(flat, feed_dict=first_feed)
        print('x_flat_np:\n', first_flat_np, '\n')

        # The very same graph is reused for a second input of another shape.
        second_np = np.arange(12).reshape(2, 3, 2)
        print('x_np:\n', second_np, '\n')

        second_feed = {inp: second_np}
        second_flat_np = session.run(flat, feed_dict=second_feed)
        print('x_flat_np:\n', second_flat_np)
# Run the flatten demonstration immediately when this code is executed.
test_flatten()