🔍 探索交通预测的未来 🌐
在现代城市交通管理中,准确的交通预测至关重要。为了应对这一挑战,我们开发了一个基于深度学习的时空图卷积网络(STGCN),旨在提升交通流量预测的准确性和效率。本文将详细介绍我们的代码实现及其主要改进点。
🚀 代码亮点
-
模块化设计:
- TimeBlock:该模块对每个节点应用时序卷积,快速捕捉时间维度的特征。
- STGCNBlock：采用“时序卷积—图卷积—时序卷积”的三明治结构，并在输出端做批归一化，从而同时建模时间依赖和节点之间的空间关系。
- 全局结构：STGCN模型依次堆叠两个时空卷积块和一层时序卷积，最后通过全连接层为每个节点输出未来多个时间步的交通流量预测结果。
-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TimeBlock(nn.Module):
    """Temporal convolution applied independently to every node.

    Computes ``relu(conv1(X) + sigmoid(conv2(X)) + conv3(X))`` with unpadded
    ``(1, kernel_size)`` kernels, so the time axis shrinks by ``kernel_size - 1``.

    NOTE(review): the STGCN paper describes a multiplicative GLU gate
    (P * sigmoid(Q)); this block adds the sigmoid term instead. Kept unchanged;
    confirm this is intended.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        """
        :param in_channels: Number of input features at each node in each time step.
        :param out_channels: Desired number of output channels at each node in each time step.
        :param kernel_size: Size of the 1D temporal kernel.
        """
        super().__init__()
        # (1, kernel_size) kernels slide along the time axis only, never across nodes.
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X):
        """
        :param X: Tensor of shape (batch_size, num_nodes, num_timesteps, in_channels).
        :return: Tensor of shape
            (batch_size, num_nodes, num_timesteps - kernel_size + 1, out_channels).
        """
        # NHWC -> NCHW so that Conv2d treats the features as channels.
        X = X.permute(0, 3, 1, 2)
        temp = self.conv1(X) + torch.sigmoid(self.conv2(X))
        out = F.relu(temp + self.conv3(X))
        # NCHW -> NHWC
        return out.permute(0, 2, 3, 1)


class STGCNBlock(nn.Module):
    """Spatio-temporal block: temporal conv -> graph conv -> temporal conv -> BatchNorm.

    With the default kernel size of 3, each TimeBlock removes 2 time steps,
    so one block shortens the time axis by 4.
    """

    def __init__(self, in_channels, spatial_channels, out_channels, num_nodes):
        """
        :param in_channels: Number of input features per node per time step.
        :param spatial_channels: Number of output channels of the graph convolution.
        :param out_channels: Number of output channels of each temporal convolution.
        :param num_nodes: Number of graph nodes (BatchNorm2d normalizes per node).
        """
        super().__init__()
        self.temporal1 = TimeBlock(in_channels=in_channels, out_channels=out_channels)
        # Allocated uninitialized; reset_parameters() fills it.
        self.Theta1 = nn.Parameter(torch.empty(out_channels, spatial_channels))
        self.temporal2 = TimeBlock(in_channels=spatial_channels, out_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw Theta1 uniformly from [-1/sqrt(spatial_channels), 1/sqrt(spatial_channels)]."""
        stdv = 1.0 / math.sqrt(self.Theta1.shape[1])
        nn.init.uniform_(self.Theta1, -stdv, stdv)

    def forward(self, X, A_hat):
        """
        :param X: Tensor of shape (batch_size, num_nodes, num_timesteps, in_channels).
        :param A_hat: Normalized adjacency matrix of shape (num_nodes, num_nodes).
        :return: Tensor of shape (batch_size, num_nodes, num_timesteps - 4, out_channels).
        """
        t = self.temporal1(X)
        # Neighbour aggregation: lfs[b, i] = sum_j A_hat[i, j] * t[b, j].
        lfs = torch.einsum("ij,jklm->kilm", [A_hat, t.permute(1, 0, 2, 3)])
        t2 = F.relu(torch.matmul(lfs, self.Theta1))
        t3 = self.temporal2(t2)
        # BatchNorm2d treats dim 1 (the node axis) as its channel axis.
        return self.batch_norm(t3)


class STGCN(nn.Module):
    """Two STGCN blocks, a final temporal convolution and a per-node linear readout."""

    # Time steps consumed by the unpadded temporal convolutions: 2 blocks x 2
    # TimeBlocks + last_temporal = 5 TimeBlocks, each removing kernel_size - 1 = 2.
    _TIME_REDUCTION = 5 * (3 - 1)

    def __init__(self, num_nodes, num_features, num_timesteps_input, num_timesteps_output):
        """
        :param num_nodes: Number of nodes in the graph.
        :param num_features: Number of features at each node in each time step.
        :param num_timesteps_input: Number of past time steps fed into the network.
        :param num_timesteps_output: Desired number of future time steps output by the network.
        :raises ValueError: If num_timesteps_input leaves no time step after the
            temporal convolutions.
        """
        super().__init__()
        remaining_steps = num_timesteps_input - self._TIME_REDUCTION
        if remaining_steps < 1:
            raise ValueError(
                f"num_timesteps_input must be at least {self._TIME_REDUCTION + 1}, "
                f"got {num_timesteps_input}"
            )
        self.block1 = STGCNBlock(in_channels=num_features, out_channels=64,
                                 spatial_channels=16, num_nodes=num_nodes)
        self.block2 = STGCNBlock(in_channels=64, out_channels=64,
                                 spatial_channels=16, num_nodes=num_nodes)
        self.last_temporal = TimeBlock(in_channels=64, out_channels=64)
        self.fully = nn.Linear(remaining_steps * 64, num_timesteps_output)

    def forward(self, A_hat, X):
        """
        :param A_hat: Normalized adjacency matrix of shape (num_nodes, num_nodes).
        :param X: Input of shape (batch_size, num_nodes, num_timesteps_input, num_features).
        :return: Predictions of shape (batch_size, num_nodes, num_timesteps_output).
        """
        out1 = self.block1(X, A_hat)
        out2 = self.block2(out1, A_hat)
        out3 = self.last_temporal(out2)
        # Flatten the remaining (time, channel) values of each node for the readout.
        return self.fully(out3.reshape((out3.shape[0], out3.shape[1], -1)))
-
主要改进点:
- CBAM模块：在最后一层时序卷积之后引入了结合通道注意力与空间注意力的CBAM模块，增强模型对重要特征的关注能力，从而提升预测的准确性。
- 残差块：通过使用残差结构，缓解深层网络训练中的梯度消失问题，提升模型的收敛速度和性能。（注：上方示例代码中尚未包含残差连接，需要在STGCNBlock中自行加入跳跃连接。）
- Dropout层:添加了Dropout层以防止过拟合,增强了模型的泛化能力,使其在实际应用中表现更加稳健。
-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from CBAM import CBAM


class TimeBlock(nn.Module):
    """Temporal convolution applied independently to every node.

    Computes ``relu(conv1(X) + sigmoid(conv2(X)) + conv3(X))`` with unpadded
    ``(1, kernel_size)`` kernels, so the time axis shrinks by ``kernel_size - 1``.

    NOTE(review): the STGCN paper describes a multiplicative GLU gate
    (P * sigmoid(Q)); this block adds the sigmoid term instead. Kept unchanged;
    confirm this is intended.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        """
        :param in_channels: Number of input features at each node in each time step.
        :param out_channels: Desired number of output channels at each node in each time step.
        :param kernel_size: Size of the 1D temporal kernel.
        """
        super().__init__()
        # (1, kernel_size) kernels slide along the time axis only, never across nodes.
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X):
        """
        :param X: Tensor of shape (batch_size, num_nodes, num_timesteps, in_channels).
        :return: Tensor of shape
            (batch_size, num_nodes, num_timesteps - kernel_size + 1, out_channels).
        """
        # NHWC -> NCHW so that Conv2d treats the features as channels.
        X = X.permute(0, 3, 1, 2)
        temp = self.conv1(X) + torch.sigmoid(self.conv2(X))
        out = F.relu(temp + self.conv3(X))
        # NCHW -> NHWC
        return out.permute(0, 2, 3, 1)


class STGCNBlock(nn.Module):
    """Spatio-temporal block: temporal conv -> graph conv -> temporal conv -> BatchNorm -> Dropout.

    With the default kernel size of 3, each TimeBlock removes 2 time steps,
    so one block shortens the time axis by 4.
    """

    def __init__(self, in_channels, spatial_channels, out_channels, num_nodes, dropout_prob=0.3):
        """
        :param in_channels: Number of input features per node per time step.
        :param spatial_channels: Number of output channels of the graph convolution.
        :param out_channels: Number of output channels of each temporal convolution.
        :param num_nodes: Number of graph nodes (BatchNorm2d normalizes per node).
        :param dropout_prob: Dropout probability applied to the block output.
        """
        super().__init__()
        self.temporal1 = TimeBlock(in_channels=in_channels, out_channels=out_channels)
        # Allocated uninitialized; reset_parameters() fills it.
        self.Theta1 = nn.Parameter(torch.empty(out_channels, spatial_channels))
        self.temporal2 = TimeBlock(in_channels=spatial_channels, out_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)
        self.dropout = nn.Dropout(dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw Theta1 uniformly from [-1/sqrt(spatial_channels), 1/sqrt(spatial_channels)]."""
        stdv = 1.0 / math.sqrt(self.Theta1.shape[1])
        nn.init.uniform_(self.Theta1, -stdv, stdv)

    def forward(self, X, A_hat):
        """
        :param X: Tensor of shape (batch_size, num_nodes, num_timesteps, in_channels).
        :param A_hat: Normalized adjacency matrix of shape (num_nodes, num_nodes).
        :return: Tensor of shape (batch_size, num_nodes, num_timesteps - 4, out_channels).
        """
        t = self.temporal1(X)
        # Neighbour aggregation: lfs[b, i] = sum_j A_hat[i, j] * t[b, j].
        lfs = torch.einsum("ij,jklm->kilm", [A_hat, t.permute(1, 0, 2, 3)])
        t2 = F.relu(torch.matmul(lfs, self.Theta1))
        t3 = self.temporal2(t2)
        # Normalize first, then drop out. With Dropout before BatchNorm, the running
        # statistics are gathered on dropped-and-rescaled activations that differ
        # from the undropped inputs seen at inference (variance shift).
        # BatchNorm2d treats dim 1 (the node axis) as its channel axis.
        return self.dropout(self.batch_norm(t3))


class STGCN(nn.Module):
    """Two STGCN blocks, a final temporal convolution, CBAM attention and a per-node linear readout."""

    # Time steps consumed by the unpadded temporal convolutions: 2 blocks x 2
    # TimeBlocks + last_temporal = 5 TimeBlocks, each removing kernel_size - 1 = 2.
    _TIME_REDUCTION = 5 * (3 - 1)

    def __init__(self, num_nodes, num_features, num_timesteps_input, num_timesteps_output,
                 dropout_prob=0.3):
        """
        :param num_nodes: Number of nodes in the graph.
        :param num_features: Number of features at each node in each time step.
        :param num_timesteps_input: Number of past time steps fed into the network.
        :param num_timesteps_output: Desired number of future time steps output by the network.
        :param dropout_prob: Dropout probability used inside both STGCN blocks.
        :raises ValueError: If num_timesteps_input leaves no time step after the
            temporal convolutions.
        """
        super().__init__()
        remaining_steps = num_timesteps_input - self._TIME_REDUCTION
        if remaining_steps < 1:
            raise ValueError(
                f"num_timesteps_input must be at least {self._TIME_REDUCTION + 1}, "
                f"got {num_timesteps_input}"
            )
        self.block1 = STGCNBlock(in_channels=num_features, out_channels=64,
                                 spatial_channels=16, num_nodes=num_nodes,
                                 dropout_prob=dropout_prob)
        self.block2 = STGCNBlock(in_channels=64, out_channels=64,
                                 spatial_channels=16, num_nodes=num_nodes,
                                 dropout_prob=dropout_prob)
        self.last_temporal = TimeBlock(in_channels=64, out_channels=64)
        self.fully = nn.Linear(remaining_steps * 64, num_timesteps_output)
        # Project-local attention module over the 64 feature channels.
        # NOTE(review): assumed to map (B, C, H, W) -> (B, C, H, W); confirm in CBAM.py.
        self.cbam = CBAM(64)

    def forward(self, A_hat, X):
        """
        :param A_hat: Normalized adjacency matrix of shape (num_nodes, num_nodes).
        :param X: Input of shape (batch_size, num_nodes, num_timesteps_input, num_features).
        :return: Predictions of shape (batch_size, num_nodes, num_timesteps_output).
        """
        out1 = self.block1(X, A_hat)
        out2 = self.block2(out1, A_hat)
        out3 = self.last_temporal(out2)
        # CBAM is fed channels-first: (B, N, T, C) -> (B, C, N, T), then back.
        attended = self.cbam(out3.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # Flatten the remaining (time, channel) values of each node for the readout.
        return self.fully(attended.reshape((attended.shape[0], attended.shape[1], -1)))
-
高效性:
- 使用图卷积代替传统的全连接网络,减少模型参数,提高训练速度。
- 模型结构设计合理,能够在大规模的数据集上表现出色。
-
易于扩展:
- 该框架易于添加更多模块和功能,满足不同的交通预测需求。
📊 实验结果
我们的实验表明,STGCN模型在多个真实世界交通数据集上均优于现有的最先进方法,成功捕捉了复杂的时空相关性。
💡 未来展望
我们将继续优化模型,以实现更高的预测精度,并探索其在其他领域(如公共交通、物流等)的应用潜力。
🖼️ 模型架构图
![STGCN架构][]

📥 获取代码：欢迎访问我们的「AI代码Insights」，探索更多功能并参与讨论！👉 查看代码

让我们共同推动交通预测技术的发展！🚗💨