函数原型:
tf.keras.layers.TimeDistributed(
layer, **kwargs
)
除了batch_size以外,第一个维度被认为是时间维度,在进行卷积或其他操作的时候,batch_size和时间维度保持不变,对后面的维度进行处理,所以至少应该为3维。
比如(32, 10, 128, 128, 3),batch_size = 32, 包含10个时间步长的128*128的RGB图片。
inputs = tf.keras.Input(shape=(10, 128, 128, 3))
conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
outputs.shape
输出:(None, 10, 126, 126, 64)
inputs = tf.keras.Input(shape=(10, 16, 16, 3))
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(inputs)
print(outputs.shape)
输出:(None, 10, 126, 126, 64)