参考: https://keras.io/examples/vision/conv_lstm/
With deep learning, you gain a much deeper understanding by actually running the code yourself.
Preparing the data
The downloaded dataset is Moving MNIST.
Its overall shape is 20 × 10000 × 64 × 64:
10000 samples of 20 frames each; a trailing dimension of 1 is then added as the feature/channel axis.
import numpy as np

# Swap the axes representing the number of frames and number of data samples.
dataset = np.swapaxes(dataset, 0, 1)

# We'll pick out 1000 of the 10000 total examples and use those.
dataset = dataset[:1000, ...]

# Add a channel dimension since the images are grayscale.
dataset = np.expand_dims(dataset, axis=-1)
This converts the data to 1000 × 20 × 64 × 64 × 1.
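The reshaping steps above can be checked on a small dummy array (zeros stand in for the real pixels, and the sizes are scaled down purely to keep it light):

```python
import numpy as np

# Dummy stand-in for the downloaded array: 20 frames x 100 samples x 64 x 64.
dataset = np.zeros((20, 100, 64, 64), dtype=np.float32)

# Swap the frame and sample axes: (samples, frames, H, W).
dataset = np.swapaxes(dataset, 0, 1)

# Keep a subset of the samples (1000 in the original, 50 here).
dataset = dataset[:50, ...]

# Add the trailing channel dimension for grayscale images.
dataset = np.expand_dims(dataset, axis=-1)

print(dataset.shape)  # (50, 20, 64, 64, 1)
```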
Constructing the samples
def create_shifted_frames(data):
    # x is frames 0..T-2 and y is frames 1..T-1, i.e. y is x shifted forward by one frame.
    x = data[:, 0 : data.shape[1] - 1, :, :]
    y = data[:, 1 : data.shape[1], :, :]
    return x, y
Training Dataset Shapes: (900, 19, 64, 64, 1), (900, 19, 64, 64, 1)
Validation Dataset Shapes: (100, 19, 64, 64, 1), (100, 19, 64, 64, 1)
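These shapes can be reproduced end to end with dummy data (zeros in place of real pixels; the 900/100 split mirrors the 90/10 split in the example):

```python
import numpy as np

def create_shifted_frames(data):
    # y is x shifted forward by one frame along the time axis.
    x = data[:, 0 : data.shape[1] - 1, :, :]
    y = data[:, 1 : data.shape[1], :, :]
    return x, y

# Dummy dataset with the shape produced by the preparation step.
dataset = np.zeros((1000, 20, 64, 64, 1), dtype=np.float32)
train_dataset, val_dataset = dataset[:900], dataset[900:]

x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)

print(x_train.shape, y_train.shape)  # (900, 19, 64, 64, 1) (900, 19, 64, 64, 1)
print(x_val.shape, y_val.shape)      # (100, 19, 64, 64, 1) (100, 19, 64, 64, 1)
```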
Building the model
from tensorflow import keras
from tensorflow.keras import layers

# Input: a variable-length sequence of 64x64x1 frames (the time dimension is None).
inp = layers.Input(shape=(None, 64, 64, 1))

x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(5, 5),
    padding="same",
    return_sequences=True,
    activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(3, 3),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=64,
    kernel_size=(1, 1),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.Conv3D(
    filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)

# Next, we will build the complete model and compile it.
model = keras.models.Model(inp, x)
Prediction
The prediction step is the key part. I couldn't fully understand it from reading the code alone, which is why I ran it myself.
During training, the input is None × 19 × 64 × 64 × 1,
and the output has the same shape as the input.
At prediction time, however, frames are predicted one at a time; this is the so-called encoder-forecasting structure.
example = val_dataset[np.random.choice(range(len(val_dataset)), size=1)[0]]

# Pick the first/last ten frames from the example.
frames = example[:10, ...]
original_frames = example[10:, ...]

# Predict a new set of 10 frames.
for _ in range(10):
    # Extract the model's prediction and post-process it.
    new_prediction = model.predict(np.expand_dims(frames, axis=0))
    new_prediction = np.squeeze(new_prediction, axis=0)
    predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)

    # Extend the set of prediction frames.
    frames = np.concatenate((frames, predicted_frame), axis=0)
The selected example's first frames: 10 × 64 × 64 × 1
Expanded into a batch sample: 1 × 10 × 64 × 64 × 1
The prediction has the same shape; after squeezing the batch dimension: 10 × 64 × 64 × 1
The last frame is selected: 1 × 64 × 64 × 1
and appended after frames: 11 × 64 × 64 × 1
The loop thus predicts the next 10 frames step by step.
The key enabler is that the model accepts variable-length input (its time dimension is None).
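The loop's shape bookkeeping can be traced without a trained model by substituting a stub for model.predict (the stub below just returns zeros of the matching shape, which is enough to watch frames grow from 10 to 20):

```python
import numpy as np

def fake_predict(batch):
    # Stub for model.predict: the real model returns one predicted frame
    # per input frame, so the output shape equals the input shape.
    return np.zeros_like(batch)

frames = np.zeros((10, 64, 64, 1), dtype=np.float32)  # the first 10 frames

for _ in range(10):
    # (10+k, 64, 64, 1) -> batch of 1 -> "predict" -> squeeze the batch axis.
    new_prediction = fake_predict(np.expand_dims(frames, axis=0))
    new_prediction = np.squeeze(new_prediction, axis=0)
    # Keep only the last predicted frame and append it to the sequence.
    predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)
    frames = np.concatenate((frames, predicted_frame), axis=0)

print(frames.shape)  # (20, 64, 64, 1)
```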
SAConvLSTM
Now let's look at the prediction structure used in some modified versions.
Input: None × 12 × Lon × Lat × 1
Output: None × 26 × Lon × Lat × 1
Target: None × 24
Input and output together cover 38 frames.
During training, teacher forcing or scheduled sampling (a ratio of ground-truth to predicted inputs) is used, for a total of 37 steps: the first frame goes into the network to produce frame 2, and the process iterates forward step by step.
Within the input segment, ground-truth frames are always fed in; once the forecast segment is reached, either the ground truth or the previous step's prediction is fed in, and the predicted frames are emitted at the end.
So at every step the network takes a single frame as input and outputs a single frame; in other words, it is a one-step predictor applied recursively.
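A minimal sketch of that scheduled-sampling loop, assuming a hypothetical one-step stub step(frame) in place of the real network, a 12-frame input segment, and a sampling ratio p_truth (all names here are illustrative, not from the paper's code):

```python
import numpy as np

def step(frame):
    # Hypothetical stand-in for the one-frame-in, one-frame-out network.
    return frame * 0.9

rng = np.random.default_rng(0)
seq = np.zeros((38, 8, 8, 1), dtype=np.float32)  # 38 ground-truth frames (small grid)
p_truth = 0.5  # scheduled-sampling probability of feeding the ground truth

outputs = []
inp = seq[0]
for t in range(37):  # 37 one-step transitions over 38 frames
    pred = step(inp)
    outputs.append(pred)
    if t < 11:
        # Input segment (frames 0..11): always feed the true next frame.
        inp = seq[t + 1]
    else:
        # Forecast segment: feed truth or the prediction, per the sampling ratio.
        inp = seq[t + 1] if rng.random() < p_truth else pred

print(len(outputs))  # 37
```

Pure teacher forcing is the special case p_truth = 1.0, and fully autoregressive training is p_truth = 0.0; scheduled sampling anneals between the two over the course of training.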