参考资料:https://blog.csdn.net/chenxieyy/article/details/53020760
1,tf.shape(a)和a.get_shape()比较
相同点:都可以得到tensor a的尺寸
不同点:tf.shape()中a 数据的类型可以是tensor, list, array
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)
2.获取tensor的维度x.get_shape().with_rank(3),3为tensor维度,返回的还是一个元组和不加的时候一样,如果加了with_rank但维数不对会报错。如果维度错误也会报错
3.tf.stack()这是一个矩阵拼接的函数(里面传两个值,第一个是要拼接的列表,第二个是从哪个方向拼接,0表示竖直方向,把列表中的对象逐个添加到一个列表中,1表示横着拼接,把元素中的对应值拼接完加到列表中,之多元素对应值都拼接完。),
list_ = [[1, 2, 3], [4, 5, 6]]
print(type(list_))
sess=tf.Session()
a = tf.stack(list_, 0)
print sess.run(a)
a = tf.stack(list_, 1)
print sess.run(a)
[[1 2 3]
[4 5 6]]
[[1 4]
[2 5]
[3 6]]
4.tf.unstack()则是一个矩阵分解的函数(第一个值传矩阵,第二个值传分解的方向,如果为0将矩阵以竖直方向存放,如果为1将矩阵个列拼接为1个元素,然后存成列表)。
[array([1, 2, 3], dtype=int32), array([4, 5, 6], dtype=int32)]
[array([1, 4], dtype=int32), array([2, 5], dtype=int32), array([3, 6], dtype=int32)]
2,例子:
import tensorflow as tf
import numpy as np
x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]]
z=np.arange(24).reshape([2,3,4])
sess=tf.Session()
# tf.shape()
x_shape=tf.shape(x)
y_shape=tf.shape(y)
z_shape=tf.shape(z)
print sess.run(x_shape)
print sess.run(y_shape)
print sess.run(z_shape)
#a.get_shape()
x_shape=x.get_shape()
print x_shape
x_shape=x.get_shape().as_list()
print x_shape
z_shape =z.get_shape()
print z_shape
# y_shape=y.get_shape() # AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape() # AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'
而list,numpy的对象是没有getshape的属性的,只有tensor有getshape.
getshape()返回一个元组,通过aslist()方法可以把一个元组转为list
x_shape=x.get_shape()
print x_shape
x_shape=x.get_shape().as_list()
(2, 3)
[2, 3]