【论文解读】2017 STGCN: Spatio-Temporal Graph Convolutional Networks

一、简介

使用历史速度数据预测未来时间的速度。同时用于序列学习的RNN(GRU、LSTM等)网络需要迭代训练,它引入了逐步累积的误差,并且RNN模型较难训练。为了解决以上问题,我们提出了新颖的深度学习框架STGCN,用于交通预测。

二、STGCN模型架构

2.1 整体架构图示

在这里插入图片描述

2.2 ST-Conv blocks

符号含义
M历史时间序列长度
n节点数
C i C_i Ci输入的channel 数
C o C_o Co输出的channel 数

2.2.1 TemporalConv: Gated CNNs 用于提取时间特征

Note: nn.Conv2d的输入 channel在第一维度

[ P Q ] = C o n v ( x ) ; o u t = P ⊙ σ ( Q ) [P Q] = Conv(x); \\ out = P \odot \sigma (Q) [PQ]=Conv(x);out=Pσ(Q)

  • x ∈ R C i × M × n x \in \mathbb{R}^{C_i \times M \times n } xRCi×M×n
  • [ P Q ] ∈ R 2 C o ∗ ( M − K t + 1 ) × n [\text{P Q}] \in \mathbb{R}^{2C_o * (M - K_t + 1) \times n } [P Q]R2Co(MKt+1)×n

示例代码:

class TCN(nn.Module):
    def __init__(self, c_in: int, c_out: int, dia: int=1):
        """TemporalConvLayer
        input_dim:  (batch_size, 1, his_time_seires_len, node_num)
        sample:     [b, 1, 144, 207]
        Args:
            c_in (int): channel in
            c_out (int): channel out
            dia (int, optional): 空洞卷积大小. Defaults to 1.
        """
        super(TCN, self).__init__()
        self.c_out = c_out * 2
        self.c_in = c_in
        self.conv = nn.Conv2d(
            c_in, self.c_out, (2, 1), 1, padding=(0, 0), dilation=dia
        )

    def forward(self, x):
        # [batch, channel, his_n, node_num] 
        #  仅在时间维度上进行卷积 
        c = self.c_out//2
        out = self.conv(x)
        if len(x.shape) == 3: # channel, his_n, node_num
            P = out[:c, :, :]
            Q = out[c:, :, :]
        else:
            P = out[:, :c, :, :]
            Q = out[:, c:, :, :]
        return P * torch.sigmoid(Q)

2.2.2 SpatialConv: Graph CNNs 提取空间信息

迭代定义的切比雪夫多项式

o u t = Θ ∗ G x = ∑ k = 0 K − 1 θ k T k ( L ~ ) x = ∑ k = 0 K − 1 W K , l z k , l out= \Theta_{* \mathcal{G}} x = \sum_{k=0}^{K-1}\theta_k T_k(\tilde{L})x=\sum_{k=0}^{K-1}W^{K, l}z^{k, l} out=ΘGx=k=0K1θkTk(L~)x=k=0K1WK,lzk,l

  • Z 0 , l = H l Z^{0, l} = H^{l} Z0,l=Hl
  • Z 1 , l = L ~ ⋅ H l Z^{1, l} = \tilde{L} \cdot H^{l} Z1,l=L~Hl
  • Z k , l = 2 ⋅ L ~ ⋅ Z k − 1 , l − Z k − 2 , l Z^{k, l} = 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l} Zk,l=2L~Zk1,lZk2,l
  • L ~ = 2 ( I − D ~ − 1 / 2 A ~ D ~ − 1 / 2 ) / λ m a x − I \tilde{L} = 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I L~=2(ID~1/2A~D~1/2)/λmaxI

论文: Recursive formulation for fast filtering

示例代码:

class STCN_Cheb(nn.Module):
    def __init__(self, c, A, K=2):
        """spation cov layer
        Args:
            c (int): hidden dimension
            A (adj matrix): adj matrix
        """
        super(STCN_Cheb, self).__init__()
        self.K = K
        self.lambda_max = 2
        self.tilde_L = self.get_tilde_L(A)
        self.weight = nn.Parameter(torch.empty((K * c, c)))
        self.bias = nn.Parameter(torch.empty(c))
        stdv = 1.0 / np.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def get_tilde_L(self, A):
        I = torch.diag(torch.Tensor([1] * A.size(0))).float().to(A.device)
        tilde_A = A + I 
        tilde_D = torch.diag(torch.pow(tilde_A.sum(axis=1), -0.5))
        return 2 / self.lambda_max * (I - tilde_D @ tilde_A @ tilde_D) - I

    def forward(self, x):
        # [batch, channel, his_n, node_num] -> [batch, node_num, his_n, channel] -> [batch, his_n, node_num, channel] 
        x = x.transpose(1, 3)
        x = x.transpose(1, 2)
        output = self.m_unnlpp(x)
        output = output @ self.weight + self.bias
        output = output.transpose(1, 2)
        output = output.transpose(1, 3)
        return torch.relu(output) 

    def m_unnlpp(self, feat):
        K = self.K
        X_0 = feat
        Xt = [X_0]
        # X_1(f)
        if K > 1:
            X_1 = self.tilde_L @ X_0
            # Append X_1 to Xt
            Xt.append(X_1)
        # Xi(x), i = 2...k
        for _ in range(2, K):
            X_i =  2 * self.tilde_L @ X_1 - X_0
            # Add X_1 to Xt
            Xt.append(X_i)
            X_1, X_0 = X_i, X_1
        # 合并数据
        Xt = torch.cat(Xt, dim=-1)
        return Xt

2.2.3 ST-Block

组合TCNSTCN_Cheb
v l + 1 = Γ 1 ∗ T l ReLU ( Θ ∗ G l ( Γ 0 ∗ T l v l ) ) v^{l+1} = \Gamma ^{l} _{1*\mathcal{T}} \text{ReLU}( \Theta ^l_{*\mathcal{G}} (\Gamma ^{l} _{0*\mathcal{T}} v^l) ) vl+1=Γ1TlReLU(ΘGl(Γ0Tlvl))

  • Γ 0 ∗ T l v l \Gamma ^{l} _{0*\mathcal{T}} v^l Γ0Tlvl: 第一个TCN
  • Θ ∗ G l \Theta ^l_{*\mathcal{G}} ΘGl : STCN_Cheb
  • Γ 1 ∗ T l v l \Gamma ^{l} _{1*\mathcal{T}} v^l Γ1Tlvl: 第二个TCN
class STBlock(nn.Module):
    def __init__(
        self,
        A,
        K=2,
        TST_channel: List=[64, 16, 64]
        T_dia: List=[2, 4]
    ):
        # St-Conv Block1[  TCN(64, 16*2)->SCN(16, 16)->TCN(16, 64*2) ] 
        super(STBlock, self).__init__()
        self.T1 = TCN(TST_channel[0], TST_channel[1], dia=T_dia[0])
        # STCN_Cheb out have relu
        self.S = STCN_Cheb(TST_channel[1], Lk=A, K=K)
        self.T2 = TCN(TST_channel[1], TST_channel[2], dia=T_dia[1])

    def forward(self, x):
        return self.T2(self.S(self.T1(x)))

三、简单复现

复现可以看笔者的github: train.ipynb
用的数据是metr-la.h5

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
交通流预测是指利用数据分析方法对交通网络中的车辆流量进行预测和调度的过程。而ST-GCN(Spatio-Temporal Graph Convolutional Networks)则是一种针对时空图网络的深度学习方法。以下是针对ST-GCN代码的简要解释: ST-GCN代码是基于Python开发的,其主要功能是实现对时空图网络数据的预测和训练。该代码主要包括以下几个部分: 1. 数据处理:ST-GCN首先需要对原始交通流量数据进行处理和预处理。代码中会包括数据读取、数据清洗、数据规范化等操作,以确保数据的准确性和一致性。 2. 模型设计:ST-GCN采用了时空图卷积网络作为核心模型。代码中会定义和实现时空图网络的结构,包括网络层数、节点连接方式、特征提取方式等。这些节点和边的信息被表示为二维矩阵,方便进行卷积操作。 3. 训练和优化:ST-GCN通过调整网络参数来进行训练和优化。代码中包括损失函数的定义、参数初始化、梯度下降等操作,以最大程度地拟合原始数据,提高预测准确度。 4. 预测:代码还包括预测功能,用于对输入数据进行预测和推断。通过输入当前的交通流量数据,ST-GCN会输出预测结果,即未来一段时间内的车辆流量分布。 总之,ST-GCN代码是一个基于时空图卷积网络的交通流预测的实现工具。通过编写和调试这些代码,我们可以更好地理解和应用深度学习方法来处理和预测交通流量数据。同时,还可以根据实际需求对代码进行自定义和扩展,以提高预测效果和应用性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值