tf.reshape()函数介绍和示例
tf.reshape(tensor, shape, name=None)
释义:将张量 tensor 的形状改为 shape
注意:shape 设置为 -1 的位置,表示不用设置这一维的大小,函数自动进行计算(只能存在一个 -1)
示例1:
import tensorflow as tf
X = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name=None)
Y = tf.reshape(X, (3, 2)) # (2, 3) --> (3, 2)
with tf.Session() as sess:
print('原Tensor:\n', sess.run(X))
print('='*30)
print('shape(3, 2):\n', sess.run(Y))
原Tensor:
[[1. 2. 3.]
[4. 5. 6.]]
==============================
shape(3, 2):
[[1. 2.]
[3. 4.]
[5. 6.]]
示例2:
X = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name=None)
Y1 = tf.reshape(X, (-1, 6)) # (2, 3) --> (1, 6)
Y2 = tf.reshape(X, (6, -1)) # (2, 3) --> (6, 1)
with tf.Session() as sess:
print('原Tensor:\n', sess.run(X))
print('='*30)
print('shape(1, 6):\n', sess.run(Y1))
print('='*30)
print('shape(6, 1):\n', sess.run(Y2))
原Tensor:
[[1. 2. 3.]
[4. 5. 6.]]
==============================
shape(1, 6):
[[1. 2. 3. 4. 5. 6.]]
==============================
shape(6, 1):
[[1.]
[2.]
[3.]
[4.]
[5.]
[6.]]