Tensorflow中的dynamic shape、static shape及reshape、set_shape

这个问题是在学习Tensorflow当中,reshape与set_shape的区别时引出的。

在学习Tensorflow的cifar10代码的时候,发现在处理cifar数据集的读入数据时,demo使用的是如下代码来处理读入数据而不是直接使用熟悉的reshape。

float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])

官方文档对此的解释是:

The tf.Tensor.set_shape method updates the static shape of a Tensor object, and it is typically used to provide additional shape information when this cannot be inferred directly. It does not change the dynamic shape of the tensor.

The tf.reshape operation creates a new tensor with a different dynamic shape.

由此引出了dynamic shape和static shape的概念。

Tensorflow在构建图的时候,tensor的shape被称为static(inferred);而在实际运行中,常常出现图中tensor的具体维数不确定而用placeholder代替的情况,因此static shape未必是已知的。tensor在训练过程中的实际维数被称为dynamic shape,而dynamic shape是一定的。

看如下例子:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))

在实际运行的时候,显示的结果是

unknown
[4]
[2,2]

第一个结果代表构建图时x1的shape,此时它是未知的,尽管在Session中运行了以后给出了[0,1,2,3]的传入值,但无法改变图中的x1的static shape。而后面的print也可以看到,随着传入值的不同,x1的dynamic shape是会变化的。

顺带一说get_shape()方法和tf.shape()的区别,get_shape()是tensor的方法,返回一个tuple,而tf.shape()则返回一个tensor。第三行的unknown直接代表了x1的shape数组,假如我们这里用:

print(tf.shape(x1))

就会看到显示的是一个Tensor,不过它的shape里有unknown。因此,我们在获取x1的shape的时候,要用tf.shape方法,让指针指向一个tensor,不然使用get_shape()指针就会指向一个tuple从而报错。

下面说下set_shape()和reshape()的区别。其实从官方说明中可以看出,这两个主要是适用场合的区别,前者用于更新图中某个tensor的shape,而后者则往往用于动态地创建一个新的tensor。

一个set_shape的典型用法如下:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x1.set_shape([22])
print(x1.get_shape())

sess = tf.Session()
#print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))

此时,运行结果为:

(2,2)
[2,2]

这代表了图中最开始没有shape的x1在使用了set_shape后,它的图中的信息已经改变了,如果取消掉注释就会报错,因为我们传入了和图不符合的参数。

reshape的典型用法则是这样:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x2 = tf.reshape(x1, [2,2])
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x2), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x2), feed_dict={x1:[[0,1],[2,3]]}))

此时运算结果为:

(2,2)
[2,2]
[2,2]

即它并不是想改变图,而只是想创造一个新的tensor以供我们使用。

但是reshape能否和set_shape有着相同的用法,即用来改变图?我们试着修改上面的代码:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x1 = tf.reshape(x1, [2,2]) # use tf.reshape()
print(tf.shape(x1))

sess = tf.Session()
#print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))

经测试,reshape后x1的shape也发生了变化,注释不取消仍然会有报错现象。

那么set_shape和reshape的用法是否完全一样呢?还是有一定差别的。reshape可以改变原有tensor的shape,而set_shape只能更新信息没办法直接改变值,可以参考下面的程序:

import tensorflow as tf
x1 = tf.Variable([[0, 1], [2, 3]])
print(x1.get_shape())

x1 = tf.reshape(x1, [4, 1]) # if we use x1.set_shape([4, 1]),the program cannot run
print(x1.get_shape())

最后总结一下吧,reshape应用场合比较广泛,当我们需要创建新的tensor或者动态地改变原有tensor的shape的时候可以使用;而当我们只是想更新图中某个tensor的shape或者补充某个tensor的shape信息可以使用set_shape来进行更新。

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值