TensorFlow 获取张量形状的操作 tf.shape()、属性shape 及 方法get_shape() 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

1、tf.shape(tensor)

获取输入张量 input 的形状,以 1 维整数张量形式表示

https://tensorflow.google.cn/api_docs/python/tf/shape

tf.shape(
    input,
    name=None,
    out_type=tf.int32
)

参数:

input:张量或稀疏张量

name:可选参数,操作的名称

out_type:可选参数,指定输出张量的数据类型(int32 或 int64),默认是 tf.int32

返回:

指定 out_type 数据类型的张量

 

2、tensor.shape

张量的形状属性

返回一个表示该张量的形状 tf.TensorShape

对于每个操作,通过注册在 Op 中的形状推断函数来计算该张量的形状,形状表示的更多信息请参考 tf.TensorShape

不需要在会话中启动图 Graph 的情况下,张量的推断形状用来表示形状信息。该信息可以用来调试,提供早期的错误信息

在某些情况下,推断出的形状可能存在未知的维度。如果调用者有关于这些维度值的额外信息,就可以使用 Tensor.set_shape() 来拓展该推断的形状 

https://tensorflow.google.cn/api_docs/python/tf/Tensor

 

3、tensor.get_shape()

tensor.shape 的别名,即同样返回一个表示该张量的形状 tf.TensorShape 的方法

https://tensorflow.google.cn/api_docs/python/tf/Tensor

注意:tensor.get_shape(),不是get_shapes

 

三、实例

1、操作 tf.shape()、属性shape 及 方法get_shape() 的基本用法

>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v
<tf.Variable 'Variable:0' shape=(100, 100) dtype=float32_ref>

# tf.shape() 方法
>>> tf.shape(v)
<tf.Tensor 'Shape:0' shape=(2,) dtype=int32>

# shape 属性
>>> v.shape
TensorShape([Dimension(100), Dimension(100)])

# get_shape() 方法
>>> v.get_shape()
TensorShape([Dimension(100), Dimension(100)])

# 错误的用法举例
# 将属性当成方法
>>> v.shape()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'TensorShape' object is not callable

# 将方法当成属性
>>> v.get_shape
<bound method RefVariable.get_shape of <tf.Variable 'Variable_1:0' shape=(100, 100) dtype=float32_ref>>

2、操作 tf.shape() 及属性shape 与 方法get_shape() 的区别

(1)操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出

(2)方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape,tf.TensorShape 可以通过 as_list() 方法将形状转换为列表形式

https://tensorflow.google.cn/api_docs/python/tf/TensorShape

# 操作 tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
# 方法 get_shape() 和 属性 shape 都返回一个表示该张量形状的 tf.TensorShape

>>> import tensorflow as tf
>>> import tensorflow as tf
>>> v = tf.Variable(initial_value=tf.truncated_normal([100,100]))
>>> v
<tf.Variable 'Variable:0' shape=(100, 100) dtype=float32_ref>



# tf.shape() 方法
>>> tensor_shape = tf.shape(v)
>>> tensor_shape
<tf.Tensor 'Shape_1:0' shape=(2,) dtype=int32>

# tf.shape() 返回的 Tensor 没有 as_list() 方法,所以报错
>>> tensor_shape.as_list()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'Tensor' object has no attribute 'as_list'

# tf.shape() 则返回一个形状张量,必须在会话 Session 中才能打印输出
>>> with tf.Session() as sess:
...     print(sess.run(tensor_shape))
...
[100 100]




# shape 属性返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_1 = v.shape
>>> shapes_1
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_1 = shapes_1.as_list()
>>> shapes_list_1
[100, 100]




# get_shape() 方法返回一个 tf.TensorShape,可以通过 as_list() 方法将形状转换为列表形式
>>> shapes_2 = v.get_shape()
>>> shapes_2
TensorShape([Dimension(100), Dimension(100)])
>>> shapes_list_2 = shapes_2.as_list()
>>> shapes_list_2
[100, 100]

 

  • 8
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值