最快的学习方式是:
自己手动建立一层,看一下输入输出的维度。比如:
hiden = layers.LSTM(2)
x = tf.constant(np.arange(45).reshape(3,5,3),dtype=tf.float32)
y = hiden(x)
发现,LSTM的输入一定是3维的,分别是(Batch, 时间戳,属性维度)。输出一定是2维的,分别是(Batch, lstm_units)。所以,LSTM可以直接连接Dense层。但是Conve层必须先flat才能Dense层。
self.flat = layers.Flatten()