从名字来看,这2个接口都是可以获取tensor的shape的,但有明显的区别,具体为:
1. tf.shape返回的是tensor,而tf.get_shape返回的是一个元组,所以前者想要获取tensor具体的shape结果需要sess.run才行;
2. 对tf.placeholder占位符来说,如果shape设置的其中某一个是None,那么对于tf.shape,sess.run会报错,而tf.get_shape不会,它会在None位置显示“?”表示此位置的shape暂时未知。
a = tf.Variable(tf.constant(1.5, dtype=tf.float32, shape=[1,2,3,4,5,6,7]), name='a')
b = tf.placeholder(dtype=tf.int32, shape=[None, 3], name='b')
s1 = tf.shape(a)
s2 = a.get_shape()
print (s1) # Tensor("Shape:0", shape=(7,), dtype=int32)
print (s2) # 元组 (1, 2, 3, 4, 5, 6, 7)
s11 = tf.shape(b)
s21 = b.get_shape()
print (s11) # Tensor("Shape_1:0", shape=(2,), dtype=int32)
print (s21) # 因为第一位设置的是None,所以这里的第一位显示问号表示暂时不确认 (?, 3)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print (sess.run(s1)) # [1 2 3 4 5 6 7]
print (sess.run(s11))
# InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'b' with dtype int32
# [[Node: b = Placeholder[dtype=DT_INT32, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
在代码中,经常用到对tensor的reshape操作,借助这2个接口可以灵活的进行reshape,方法如下:
a = tf.Variable(tf.constant(1.5, dtype=tf.float32, shape=[4, 6]), name='a')
b = tf.placeholder(dtype=tf.int32, shape=[4, 60], name='b')
s1 = tf.shape(a)
s2 = a.get_shape().as_list() # 注意,这里需要将元组转成list,才能在下面的reshape中使用,否则会报错:
# TypeError: Expected binary or unicode string, got Dimension(6)
c = tf.reshape(b, shape=[s1[1], -1])
print (c) # Tensor("Reshape:0", shape=(?, ?), dtype=int32)
d = tf.reshape(b, shape=[s2[1], -1])
print (d) # Tensor("Reshape_1:0", shape=(6, 40), dtype=int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
after_reshape = sess.run(c, feed_dict={b:np.ones(shape=[4, 60], dtype=np.int32)})
print (after_reshape.shape) # (6, 40)
after_reshape = sess.run(d, feed_dict={b:np.ones(shape=[4, 60], dtype=np.int32)})
print (after_reshape.shape) # (6, 40)