在前面说了PyG这个框架,但是这个框架处理数据其实没那么简单,并且有时候我们想要改变底层的图卷积框架时就无能为力了,所以这一章说一下用PyTorch怎么写出图卷积并且实现交通流量数据的预测。但在这之前,需要先处理好需要的数据。
下一小节:链接
文章目录
一、数据来源
数据来自美国的加利福尼亚州的洛杉矶市,第一个CSV文件是关于节点的表示情况,一共有307个节点,第二个npz文件是交通流量的文件,时间范围是两个月(2018.1.1——2018.2.28),每5分钟测一次。
下载地址:
链接:https://pan.baidu.com/s/1GlssEHKgf9agTRsdhPPuRA
提取码:q7v9
二、数据分析
在处理数据之前,先看看拿到的数据长什么样子,我们可视化数据看看:
import numpy as np
import matplotlib.pyplot as plt
def get_flow(file_name): # 将读取文件写成一个函数
flow_data = np.load(file_name) # 载入交通流量数据
print([key for key in flow_data.keys()]) # 打印看看key是什么
print(flow_data["data"].shape) # (16992, 307, 3),16992是时间(59*24*12),307是节点数,3表示每一维特征的维度(类似于二维的列)
flow_data = flow_data['data'] # [T, N, D],T为时间,N为节点数,D为节点特征
return flow_data
# 做工程、项目等第一步对拿来的数据进行可视化的直观分析
if __name__ == "__main__":
traffic_data = get_flow("PeMS_04/PeMS04.npz")
node_id = 10
print(traffic_data.shape)
plt.plot(traffic_data[:24*12, node_id, 0]) # 0维特征
plt.savefig("node_{:3d}_1.png".format(node_id))
plt.plot(traffic_data[:24 * 12, node_id, 1]) # 1维特征
plt.savefig("node_{:3d}_2.png".format(node_id))
plt.plot(traffic_data[:24 * 12, node_id, 2]) # 2维特征
plt.savefig("node_{:3d}_3.png".format(node_id))
运行结果如下:
所以可得出:每个节点有三个特征,但是其他两个节点基本是平稳不变的,所以我们只取第一维特征。
三、数据处理
1、读入数据并取需要的特征
前面说了,只取第一维特征,并且为了后面方便,将节点的维度放在第一维,所以重写的get_flow()函数如下:
import csv
import torch
import numpy as np
from torch.utils.data import Dataset
def get_flow(file_name):
flow_data = np.load(file_name)
print([key for key in flow_data.keys()])
print(flow_data["data"].shape) # (16992, 307, 3),16992是时间(59*24*12),307是节点数,3表示每一维特征的维度(类似于二维的列)
flow_data = flow_data['data'].transpose([1, 0, 2])[:, :, 0][:, :, np.newaxis] # [N, T, D],transpose就是转置,让节点纬度在第0位,N为节点数,T为时间,D为节点特征
# 对np.newaxis说一下,就是增加一个维度,这是因为一般特征比一个多,即使是一个,保持这样的习惯,便于通用的处理问题
return flow_data
2、数据集处理:写成PyTorch所需要的数据集的类
说明:下面的所有代码都在traffic_dataset.py文件中,我为了说明把几个函数分开了,运行的时候只需要全复制到一个.py文件中即可。
(1)读入邻接矩阵
首先当然是邻接矩阵的读取,前面的.CSV文件其实就是一个直接能可视化的邻接矩阵,当然不是我们所需的那种,所以才需要处理嘛。如下所示,from和to表示的是节点,cost表示的是两个节点之间的直线距离(来表示权重),在本文中,权重都为1。
处理成邻接矩阵的程序如下:
import csv
import torch
import numpy as np
from torch.utils.data import Dataset
def get_adjacent_matrix(distance_file: str, num_nodes: int, id_file: str = None, graph_type="connect") -> np.array:
"""
:param distance_file: str, path of csv file to save the distances between nodes.
:param num_nodes: int, number of nodes in the graph
:param id_file: str, path of txt file to save the order of the nodes.就是排序节点的绝对编号所用到的,这里排好了,不需要
:param graph_type: str, ["connect", "distance"],这个就是考不考虑节点之间的距离