PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测

27 篇文章 49 订阅
20 篇文章 35 订阅

I. 前言

前面已经写过不少时间序列预测的文章:

  1. 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
  2. PyTorch搭建LSTM实现时间序列预测(负荷预测)
  3. PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测
  4. PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
  5. PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
  6. PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  7. PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  8. PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  9. PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  10. PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  11. PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
  12. PyTorch-LSTM时间序列预测中如何预测真正的未来值
  13. PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  14. PyTorch搭建ANN实现时间序列预测(风速预测)
  15. PyTorch搭建CNN实现时间序列预测(风速预测)
  16. PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  17. PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
  18. PyTorch时间序列预测系列文章总结(代码使用方法)
  19. TensorFlow搭建LSTM实现时间序列预测(负荷预测)
  20. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
  21. TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
  22. TensorFlow搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  23. TensorFlow搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  24. TensorFlow搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  25. TensorFlow搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  26. TensorFlow搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  27. TensorFlow搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  28. TensorFlow搭建ANN实现时间序列预测(风速预测)
  29. TensorFlow搭建CNN实现时间序列预测(风速预测)
  30. TensorFlow搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  31. PyG搭建图神经网络实现多变量输入多变量输出时间序列预测
  32. PyTorch搭建GNN-LSTM和LSTM-GNN模型实现多变量输入多变量输出时间序列预测
  33. PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测
  34. 时序预测中Attention机制是否真的有效?盘点LSTM/RNN中24种Attention机制+效果对比
  35. 详解Transformer在时序预测中的Encoder和Decoder过程:以负荷预测为例
  36. (PyTorch)TCN和RNN/LSTM/GRU结合实现时间序列预测
  37. PyTorch搭建Informer实现长序列时间序列预测
  38. PyTorch搭建Autoformer实现长序列时间序列预测

从第31篇文章起,本系列开始更新时空预测模型,其中前两篇文章都不是属于论文中的模型,今天介绍一个使用较为广泛的用于时序预测的时空图卷积网络STGCN。

II. STGCN

STGCN是北大发表在IJCAI 2018上的论文Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting中提出来的,其目的是用于实时的交通预测。

在该论文中使用的数据集为美国加州PeMSD7数据集,里面包含了分布在不同地方的228个传感器观测到的车流量,文章中使用这228个节点构成了一个无向图,然后根据历史的车流量信息预测未来某个时间段的所有传感器所在地的车流量信息。

可以看出,STGCN要解决的问题与前两篇文章要解决的问题基本一致。前两篇问题中,我们给出了13个变量前24小时的数据,目的是预测13个变量未来某几个小时的数据。在这里13个变量类比于228个传感器。

STGCN的原理也较为简单,STGCN由两个时空图卷积块(ST-Conv Block)和一个输出全连接层(Output Layer组成。其中ST-Conv Block又由两个时间门控卷积和中间的一个空间图卷积组成:
在这里插入图片描述
从图右边可知,两个Temporal Gated-Conv使用的是1-D卷积,和CNN处理一维时序信号类似,即进行seq_len维度上的卷积。Spatial Graph-Conv进行的是空域上的卷积,模型为GCN。

关于STGCN详细的原理可以阅读原论文,原理也比较简单。本篇文章不做太多详细的推导过程,主要讲解如何利用STGCN进行多变量输入多变量输出的时间序列预测。

III. PyG Temporal

PyG Temporal是PyG的一个扩展库,其主要用于处理时空信号数据,里面实现了许多使用较为广泛的时空图卷积模型如STGCN、DCRNN、T-GCN、LRGCN等。

PyG Temporal的安装也比较简单:

pip install torch-geometric-temporal

PyG Temporal中STGCN的实现如下:
在这里插入图片描述
参数解释如下:

  1. in_channels:节点输入特征的维度大小,这里为1,即每个节点都只有一个特征,我们需要预测的也是该特征。
  2. hidden_channels:字面意思。
  3. out_channels:字面意思。
  4. kernel_size:时域卷积时的卷积核大小,类比CNN即可。
  5. K:将切比雪夫多项式作为图卷积核时的卷积核大小,具体可以参考我之前写的一篇文章:ICML 2019 | SGC:简单图卷积网络
  6. normalization:拉普拉斯矩阵的归一化选项,前面也讲过了。
  7. bias:无需多述。

一个STConv所能接受的输入格式为:
在这里插入图片描述
可以看出,一个STConv需要接受三个输入:

  1. X:维度大小为(batch_size, seq_len, num_nodes, in_channels),在本文中即X=(256, 24, 13, 1)
  2. edge_index:图的邻接矩阵。
  3. edge_weight:边权重矩阵(可选)。

为此,我们可以首先搭建一个STGCN:

class STGCN(nn.Module):
    def __init__(self, num_nodes, size, K):
        super(STGCN, self).__init__()
        self.conv1 = STConv(num_nodes=num_nodes, in_channels=1, hidden_channels=16,
                            out_channels=32, kernel_size=size, K=K)
        self.conv2 = STConv(num_nodes=num_nodes, in_channels=32, hidden_channels=16,
                            out_channels=32, kernel_size=size, K=K)

    def forward(self, x, edge_index):
        # x(batch_size, seq_len, num_nodes, in_channels)
        x, edge_index = x.to(device), edge_index.to(device)
        x = F.elu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        return x

然后一个用于多变量输入多变量输出的STGCN模型搭建如下:

class STGCN_MLP(nn.Module):
    def __init__(self, args):
        super(STGCN_MLP, self).__init__()
        self.args = args
        self.out_feats = 128
        self.stgcn = STGCN(num_nodes=args.input_size, size=3, K=1)
        self.fcs = nn.ModuleList()
        for k in range(args.input_size):
            self.fcs.append(nn.Sequential(
                nn.Linear(16 * 32, 64),
                nn.ReLU(),
                nn.Linear(64, args.output_size)
            ))

    def forward(self, x, edge_index):
        # x(batch_size, seq_len, input_size)
        # x(512, 24, 13)--->(512, 24, 13, 1)
        x = x.unsqueeze(3)
        x = self.stgcn(x, edge_index)
        preds = []
        for k in range(x.shape[2]):
            preds.append(self.fcs[k](torch.flatten(x[:, :, k, :], start_dim=1)))

        pred = torch.stack(preds, dim=0)

        return pred

照例简单分析一下模型的处理过程:

首先我们有x=(batch_size=256, seq_len=24, input_size=13),为了满足STGCN的输入要求(batch_size, seq_len, num_nodes, in_channels=1),我们需要将x扩展一个维度:

x = x.unsqueeze(3)

然后经过STGCN:

x = self.stgcn(x, edge_index)

得到x=(256, 16, 13, 32)。操作过程与CNN类似,一维卷积作用在seq_len=24维度,最终变成16。随后,为了得到每个变量的输出,我们简单地将13个变量各自的(16, 32)经过13个不同的全连接层。

IV. 模型训练/测试

这点与前面一致,不再赘述。

预测效果相当不错:
在这里插入图片描述
预测效果示意图(只给出前6个变量):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

V. 代码

后续考虑整理公开。

  • 15
    点赞
  • 88
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 17
    评论
利用pyg库可以实现时间序列预测Pyg是一个强大的Python库,被广泛用于图神经网络(graph neural networks, GNNs)的开发和研究。它提供了许多用于处理图数据的功能和模型。 要使用pyg库进行时间序列预测,我们首先需要将时间序列数据转化为图数据的形式。一种常见的方法是将每个时间点视为图中的一个节点,并通过边连接相邻时间点的节点。然后,我们可以使用pyg库的函数和模型来处理和预测时间序列数据。 首先,我们可以使用pyg库的`torch_geometric.data.Data`类来表示图数据。我们可以使用这个类来创建一个包含节点特征、边索引和边特征的图对象。对于时间序列数据,我们可以将每个时间点的特征作为节点特征,并使用相邻时间点的索引作为边索引。 然后,我们可以使用pyg库的模型来预测时间序列的未来值。比如,我们可以使用图卷积神经网络(graph convolutional neural network, GCN)模型。该模型可以从图数据中学习节点的表示,并进行预测。我们可以使用pyg库的`torch_geometric.nn`模块来创建和训练GCN模型。 在预测过程中,我们可以根据需要选择不同的损失函数和优化器,以优化模型的性能。我们可以使用pyg库提供的损失函数和优化器,如MSE损失函数和Adam优化器。 总之,利用pyg库可以方便地处理和预测时间序列数据。通过将时间序列转化为图数据的形式,并使用pyg库提供的函数和模型来处理,我们可以实现精确的时间序列预测

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值