tensorflow之中的get_shape()函数好tf.shape()函数输出的形状有很大的区别
这里我们以一个小例子来说明具体用法的不同
input_ids = keras.layers.Input(shape=(None,),dtype='int32',name="token_ids")
input_shape = input_ids.get_shape()
print('input_shape1 = ')
print(input_shape)
maxlen = input_shape[1]
print(maxlen==None)
input_shape = tf.shape(input=input_ids)
print('input_shape2 = ')
print(input_shape)
maxlen = input_shape[1]
print(maxlen==None)
对应的输出内容如下:
可以看出get_shape()函数和tf.shape()函数的区别:get_shape()函数输出的内容为一个list类型的数组,而tf.shape()输出的为一个tensor类型的数据
所以这里的
maxlen = input_shape[1]
在get_shape()之后从数组中直接得到对应的None的值,而
maxlen = input_shape[1]
得到的是一个对应的tensor值
KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.int32, name=None), inferred_value=[None], name='tf.__operators__.getitem_1/strided_slice:0', description="created by layer 'tf.__operators__.getitem_1'")
所以这里tensor判断是否为None的时候就会报false。
但是在
output_shape = [batch_size,seq_len,
self.num_attention_heads,self.size_per_head]
output_tensor = K.reshape(input_tensor,output_shape)
这里面如果放入常规的None,None的时候,K.reshape操作会报错,所以这里放入的内容必须为KerasTensor类型的数据。