tensorflow中的静态维度和动态维度

参考:

1. TensorFlow: Shapes and dynamic dimensions一文中,对张量的静态和动态维度做了描述。

  • 使用tf.get_shape()获取静态维度
  • 使用tf.shape获取动态维度
    如果你的placeholder输入的维度都是固定的情况下,使用get_shape()。但是很多情况下,我们希望想训练得到的网络可以用于任意大小的图像,这时你的placeholder就的输入维度都是[None,None,None,color_dim]这样的,在这种情况下,后续网络中如果需要得到tensor的维度,则需要使用tf.shape。

2 . https://blog.csdn.net/LoseInVain/article/details/78762739

3. tensor.shap.as_list()返回静态维度。tf.shape(tensor)返回动态维度

参考bert->modeling->get_shape_list()函数

def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, and exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)
  # tensor.shape.as_list()返回静态维度
  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape

4. 示例:

import tensorflow as tf
a = tf.placeholder(dtype=tf.int32, shape=[None, None])
# static dimensions
a_static = a.shape.as_list()
print(a_static)
# dynamic dimensions
b = tf.shape(a)
print(b)
dynamic_index = []
for i, v in enumerate(a_static):
    if v is None:
        dynamic_index.append(i)
for i in dynamic_index:
    a_static[i] = b[i]
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([a_static], feed_dict={a:[[1,2],[3,4]]}))

#
[None, None]
Tensor("Shape:0", shape=(2,), dtype=int32)
[[2, 2]]

 若稍微修改下,会报错:

import tensorflow as tf
a = tf.placeholder(dtype=tf.int32, shape=[None, 2])
# static dimensions
a_static = a.shape.as_list()
print(a_static)
# dynamic dimensions
b = tf.shape(a)
print(b)
dynamic_index = []
for i, v in enumerate(a_static):
    if v is None:
        dynamic_index.append(i)
for i in dynamic_index:
    a_static[i] = b[i]
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([a_static], feed_dict={a:[[1,2],[3,4]]}))



#
[None, 2]
Tensor("Shape:0", shape=(2,), dtype=int32)
TypeError: Can not convert a int into a Tensor or Operation.
TypeError: Fetch argument 2 has invalid type <class 'int'>, must be a string or Tensor. (Can not convert a int into a Tensor or Operation.)

原因 : a_static[1]是整数,不是tensor

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值