今天介绍tensor.get_shape()的用法
tensor.get_shape()
返回张量的维度,用一个元组来表示。
例如:
x=tf.constant([[[1,1,1],[2,2,2],[3,3,3],[4,4,4]],
[[1,1,1],[2,2,2],[3,3,3],[4,4,4]],
[[1,1,1],[2,2,2],[3,3,3],[4,4,4]]])
#容易看出,这是一个三维的张量(tensor)
x.get_shape()
x.get_shape()[1:]
x.get_shape()[1:].num_elements()
以上三个用法分别是什么意思呢?
第一个返回:(3,4,3)
用张量来说的话就是三个四行三列的张量,分别依次对应3,4,3
第二个返回(4,3)
就是切片操作,从第一个元素开始往后全部取来。
注意切片[a:b] 中,包含a而不包含b的元素。
第三个返回12
即:(4,3)中的元素个数,也就是12。
有什么用呢?再卷积神经网络中,用此种方法可以在卷积层拉伸为全连接层时,免于人工计算应该拉成多长。
features=last_conv_layer.get_shape()[1:].num_elements()
#第一个维度是数据的个数(batch_size组)
result =last_conv_layer.reshape([-1,features])
#拉伸成全连接层,-1是不知道有多少组数据的意思。