GCN-LSTM实现时空预测

        简介:现有的预测模型越来考虑时间和空间的相关性,统称为时空预测。这种预测模型往往比简单的序列模型(例如RNN、LSTM、GRU及其变体)、Transformer等效果更好。我使用Keras实现了该GCN-LSTM代码,因为Keras相比于torch更容易入手和理解。我实现了一个基于Keras的GCN网络层,可以像Keras中调用Dense、LSTM等网络层一样随时调用这个层。需要电脑安装tensorflow和keras。keras的版本为2.3。tensorflow的版本为2.1。

1、模型的输入数据形状

        由于是时空数据,模型的输入形状为:[批次大小,时间步长,节点个数,维度数量]。具体的来说,就是[batchsize,node,time,dim]。批次大小表示一次性喂给模型的样本数量、节点个数就是图中的实体节点数量、时间步长就是每个节点记录的一段时间序列数据(可以是单维度、多维度的)的长度、维度数量就是每个时刻点记录的变量个数。

        如下图所示,输入为[None,30,11,6],表示批次小自动调节,输入的时间段为30步,一共有11个节点、每个时刻有6个特征。具体思路为:首先对于每个时刻,进行图卷积,实现节点之间的信息传递;其次,使用LSTM压缩所有时刻的信息到一个一维张量。

2、邻接矩阵的构建和预处理

        邻接矩阵往往需要根据所研究的具体场景决定,比如每辆车看作是节点、每条路段看作是节点、每个行人看作是节点、或者每个通道特征图看作是节点也是可以的。

        对于边来说,就是考虑节点之间的关系。计算GCN计算公式中归一化邻接矩阵D的代码如下。

def getAdj(A):
    """
    A:自带自环的矩阵
    """
    s = np.sum(A, axis=1)
    s = np.sqrt(s)
    s = 1/s
    D = np.diag(s)
    D = D.astype(np.float32)
    return np.sqrt(D)

3、Keras实现GCN卷积

        我实现了一个基于Keras的GCN网络层,可以随时调用这个层。

4、其他代码内容

        MAE、MSE、R2等指标,各种图形可视化都有代码。

from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

def calculate_metrics(y_true, y_pred):
    r2 = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    return r2, mae, mse

# 示例
y_true = yt
y_pred = yp
r2, mae, mse = calculate_metrics(y_true, y_pred)

print("R^2 Score:", r2)
print("Mean Absolute Error:", mae)
print("Mean Squared Error:", mse)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值