tf.shape()
shape(
input,
name=None,
out_type=tf.int32
)
返回一个代表 input 的 shape 的 1-D tensor.
也就是说获得shape的信息,需要用sess.run().
>>> import tensorflow as tf
>>> a = tf.constant([[1, 2, 3], [4, 5, 6]])
>>> a # 2-D tensor
<tf.Tensor 'Const:0' shape=(2, 3) dtype=int32>
>>> a_shape = tf.shape(a)
>>> a_shape # 1-D tensor, 长度和a的维度一样
<tf.Tensor 'Shape:0' shape=(2,) dtype=int32>
>>> sess = tf.Session()
>>> a_shape_eval = sess.run(a_shape)
>>> a_shape_eval
array([2, 3], dtype=int32)
tensor.get_shape()
返回一个tf.TensorShape的类,代表当前tensor的shape
不需要运行计算图就可以获得shape的信息.
使用tensor.shape也可以得到当前tensor的shape
>>> import tensorflow as tf
>>> a = tf.constant([[1, 2, 3], [4, 5, 6]])
>>> a # 2-D tensor
<tf.Tensor 'Const:0' shape=(2, 3) dtype=int32>
>>> a.shape # 类型是 tf.TensorShape
TensorShape([Dimension(2), Dimension(3)])
>>> a.get_shape() # 功能和tensor.shape一样
TensorShape([Dimension(2), Dimension(3)])
>>> a.get_shape().as_list() # 转化为list
[2, 3]