References:
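1. The article "TensorFlow: Shapes and dynamic dimensions" describes the static and dynamic dimensions of a tensor:
- Use tensor.get_shape() (equivalently tensor.shape) to get the static dimensions
- Use tf.shape(tensor) to get the dynamic dimensions
If every dimension of your placeholder input is fixed, get_shape() is enough. Often, however, we want the trained network to accept images of any size, so the placeholder's shape looks like [None, None, None, color_dim]. In that case, any later part of the network that needs the tensor's dimensions must use tf.shape, because get_shape() only reports None for the unknown dimensions.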
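A minimal sketch of the variable-size-image case, assuming TensorFlow graph mode (written against the `tf.compat.v1` API so it runs on TF 1.14+ and TF 2.x). The placeholder name `images` and the zero-filled NumPy batches are illustrative assumptions. The static shape keeps `None` for the unknown dimensions, while `tf.shape` reports the actual size of whatever batch is fed to the same graph.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and Session need graph mode

# Batch, height and width are unknown; only the color channels are fixed.
images = tf.placeholder(tf.float32, shape=[None, None, None, 3])

static_shape = images.shape.as_list()  # fixed at graph-construction time
dynamic_shape = tf.shape(images)       # evaluated each time the graph runs

with tf.Session() as sess:
    small = np.zeros((1, 32, 48, 3), dtype=np.float32)
    large = np.zeros((2, 64, 64, 3), dtype=np.float32)
    small_shape = sess.run(dynamic_shape, feed_dict={images: small})
    large_shape = sess.run(dynamic_shape, feed_dict={images: large})

print(static_shape)          # [None, None, None, 3]
print(small_shape.tolist())  # [1, 32, 48, 3]
print(large_shape.tolist())  # [2, 64, 64, 3]
```

The same graph serves both batches; only the values produced by `tf.shape` change between runs.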
2. https://blog.csdn.net/LoseInVain/article/details/78762739
3. tensor.shape.as_list() returns the static dimensions; tf.shape(tensor) returns the dynamic dimensions.
For how BERT combines the two, see get_shape_list() in bert->modeling:
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.
  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.
  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name
  if expected_rank is not None:
    # assert_rank is another helper defined in BERT's modeling.py
    assert_rank(tensor, expected_rank, name)
  # tensor.shape.as_list() returns the static dimensions (None = unknown)
  shape = tensor.shape.as_list()
  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)
  # Fully static shape: plain Python ints, no graph op needed
  if not non_static_indexes:
    return shape
  # Fill only the unknown dimensions with scalar tensors from tf.shape
  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape
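The merge rule in get_shape_list (keep every static dimension, fall back to the dynamic one only where the static value is None) can be sketched without TensorFlow. In this hypothetical `merge_shape` helper, the `dynamic` list stands in for what `tf.shape(tensor)` would provide; in real TensorFlow code those entries are scalar tensors, so strings are used here as placeholders for them.

```python
from typing import Any, List, Optional, Sequence


def merge_shape(static: Sequence[Optional[int]], dynamic: Sequence[Any]) -> List[Any]:
    """Hypothetical helper mirroring get_shape_list's merge rule.

    static:  what tensor.shape.as_list() returns (None = unknown dimension)
    dynamic: stand-in for tf.shape(tensor), one entry per dimension
    """
    if len(static) != len(dynamic):
        raise ValueError("static and dynamic shapes must have the same rank")
    shape = list(static)
    # Indexes whose size is only known at run time.
    non_static_indexes = [i for i, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        return shape  # fully static: no dynamic lookup needed
    for i in non_static_indexes:
        shape[i] = dynamic[i]
    return shape


# Strings play the role of the scalar tensors dyn_shape[i].
print(merge_shape([None, 2], ["dyn[0]", "dyn[1]"]))  # ['dyn[0]', 2]
print(merge_shape([4, 2], ["dyn[0]", "dyn[1]"]))     # [4, 2]
```

Note that a partially static shape comes back as a mix of Python ints and dynamic entries; with real tensors, that mix cannot be passed straight to `sess.run()` as a fetch.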
4. Example:
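import tensorflow as tf
a = tf.placeholder(dtype=tf.int32, shape=[None, None])
# static dimensions: unknown dimensions are None
a_static = a.shape.as_list()
print(a_static)
# dynamic dimensions: a 1-D int32 tensor evaluated at run time
b = tf.shape(a)
print(b)
# collect the indexes of the unknown (None) dimensions
dynamic_index = []
for i, v in enumerate(a_static):
    if v is None:
        dynamic_index.append(i)
# replace each None with the matching scalar tensor b[i]
for i in dynamic_index:
    a_static[i] = b[i]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # no variables here; harmless
    # every element of a_static is now a tensor, so this fetch is valid
    print(sess.run([a_static], feed_dict={a: [[1, 2], [3, 4]]}))
# Output: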
[None, None]
Tensor("Shape:0", shape=(2,), dtype=int32)
[[2, 2]]
A small change, fixing the second dimension to 2, makes the same code fail:
import tensorflow as tf
a = tf.placeholder(dtype=tf.int32, shape=[None, 2])
# static dimensions: the second dimension is now the Python int 2
a_static = a.shape.as_list()
print(a_static)
# dynamic dimensions
b = tf.shape(a)
print(b)
dynamic_index = []
for i, v in enumerate(a_static):
    if v is None:
        dynamic_index.append(i)
# only index 0 is replaced, so a_static becomes [b[0], 2]
for i in dynamic_index:
    a_static[i] = b[i]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # fails: the fetch list contains the plain int 2
    print(sess.run([a_static], feed_dict={a: [[1, 2], [3, 4]]}))
# Output:
[None, 2]
Tensor("Shape:0", shape=(2,), dtype=int32)
TypeError: Can not convert a int into a Tensor or Operation.
TypeError: Fetch argument 2 has invalid type <class 'int'>, must be a string or Tensor. (Can not convert a int into a Tensor or Operation.)
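Cause: only index 0 is None, so a_static ends up as [b[0], 2]. a_static[1] is the Python int 2, not a tensor, and sess.run() can only fetch graph elements such as tensors and operations (or their names), hence the TypeError.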
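Two possible fixes, sketched under the assumption of TensorFlow graph mode (again via `tf.compat.v1`): pack the mixed list into a single int32 tensor with `tf.stack` before fetching it, or fetch only `tf.shape(a)` and merge it with the static dimensions in plain Python. The names `packed`, `dyn`, and `merged` are illustrative, not from the original.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and Session need graph mode

a = tf.placeholder(dtype=tf.int32, shape=[None, 2])
b = tf.shape(a)

# Same mixed list as in the failing example: [<scalar tensor b[0]>, 2]
a_static = [b[i] if dim is None else dim for i, dim in enumerate(a.shape.as_list())]

feed = {a: [[1, 2], [3, 4], [5, 6]]}

# Fix 1: tf.stack converts the Python int 2 to a constant and packs all
# entries into one 1-D int32 tensor, which is a valid fetch.
packed = tf.stack(a_static)

with tf.Session() as sess:
    packed_value = sess.run(packed, feed_dict=feed)
    # Fix 2: fetch only the dynamic shape tensor...
    dyn = sess.run(b, feed_dict=feed)

# ...then fill in the None entries of the static shape on the Python side.
static = a.shape.as_list()
merged = [int(dyn[i]) if dim is None else dim for i, dim in enumerate(static)]

print(packed_value.tolist())  # [3, 2]
print(merged)                 # [3, 2]
```

`tf.stack` keeps the shape inside the graph, which suits later ops such as `tf.reshape`; the Python-side merge is simpler when you only need the numbers, for example for logging.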