1. lstm等维度(三维)
下面输入维度是(32,10,8),设置cell=4,最后输出的维度为(32,4)
inputs = tf.random.normal([32, 10, 8])
print(inputs) # shape=(32, 10, 8)
lstm = tf.keras.layers.LSTM(4)
output = lstm(inputs)
print(output.shape) #(32, 4)
注意:
LSTM层后面接LSTM层的话,设置return_sequences = True。
如果接全连接层的话,设置return_sequences = False。
2. CNN输入维度(四维)
3.全连接输入维度(二维即可)