简介:现有的预测模型越来考虑时间和空间的相关性,统称为时空预测。这种预测模型往往比简单的序列模型(例如RNN、LSTM、GRU及其变体)、Transformer等效果更好。我使用Keras实现了该GCN-LSTM代码,因为Keras相比于torch更容易入手和理解。我实现了一个基于Keras的GCN网络层,可以像Keras中调用Dense、LSTM等网络层一样随时调用这个层。需要电脑安装tensorflow和keras。keras的版本为2.3。tensorflow的版本为2.1。
1、模型的输入数据形状
由于是时空数据,模型的输入形状为:[批次大小,时间步长,节点个数,维度数量]。具体的来说,就是[batchsize,node,time,dim]。批次大小表示一次性喂给模型的样本数量、节点个数就是图中的实体节点数量、时间步长就是每个节点记录的一段时间序列数据(可以是单维度、多维度的)的长度、维度数量就是每个时刻点记录的变量个数。
如下图所示,输入为[None,30,11,6],表示批次小自动调节,输入的时间段为30步,一共有11个节点、每个时刻有6个特征。具体思路为:首先对于每个时刻,进行图卷积,实现节点之间的信息传递;其次,使用LSTM压缩所有时刻的信息到一个一维张量。
2、邻接矩阵的构建和预处理
邻接矩阵往往需要根据所研究的具体场景决定,比如每辆车看作是节点、每条路段看作是节点、每个行人看作是节点、或者每个通道特征图看作是节点也是可以的。
对于边来说,就是考虑节点之间的关系。计算GCN计算公式中归一化邻接矩阵D的代码如下。
def getAdj(A):
"""
A:自带自环的矩阵
"""
s = np.sum(A, axis=1)
s = np.sqrt(s)
s = 1/s
D = np.diag(s)
D = D.astype(np.float32)
return np.sqrt(D)
3、Keras实现GCN卷积
我实现了一个基于Keras的GCN网络层,可以随时调用这个层。
4、其他代码内容
MAE、MSE、R2等指标,各种图形可视化都有代码。
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
def calculate_metrics(y_true, y_pred):
r2 = r2_score(y_true, y_pred)
mae = mean_absolute_error(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)
return r2, mae, mse
# 示例
y_true = yt
y_pred = yp
r2, mae, mse = calculate_metrics(y_true, y_pred)
print("R^2 Score:", r2)
print("Mean Absolute Error:", mae)
print("Mean Squared Error:", mse)