图卷积神经网络笔记——第六章:(1)基于PyTorch的时序数据处理(交通流量数据)

在前面说了PyG这个框架,但是这个框架处理数据其实没那么简单,并且有时候我们想要改变底层的图卷积框架时就无能为力了,所以这一章说一下用PyTorch怎么写出图卷积并且实现交通流量数据的预测。但在这之前,需要先处理好需要的数据。
下一小节:链接

一、数据来源

数据来自美国的加利福尼亚州的洛杉矶市,第一个CSV文件是关于节点的表示情况,一共有307个节点,第二个npz文件是交通流量的文件,时间范围是两个月(2018.1.1——2018.2.28),每5分钟测一次。

下载地址:

链接:https://pan.baidu.com/s/1GlssEHKgf9agTRsdhPPuRA
提取码:q7v9

二、数据分析

在处理数据之前,先看看拿到的数据长什么样子,我们可视化数据看看:

import numpy as np
import matplotlib.pyplot as plt

def get_flow(file_name): # 将读取文件写成一个函数
   
    flow_data = np.load(file_name) # 载入交通流量数据
    print([key for key in flow_data.keys()]) # 打印看看key是什么
    
    print(flow_data["data"].shape)  # (16992, 307, 3),16992是时间(59*24*12),307是节点数,3表示每一维特征的维度(类似于二维的列)
    flow_data = flow_data['data'] # [T, N, D],T为时间,N为节点数,D为节点特征

    return flow_data

# 做工程、项目等第一步对拿来的数据进行可视化的直观分析
if __name__ == "__main__":

    traffic_data = get_flow("PeMS_04/PeMS04.npz")
    node_id = 10
    print(traffic_data.shape)
    
    plt.plot(traffic_data[:24*12, node_id, 0])  # 0维特征
    plt.savefig("node_{:3d}_1.png".format(node_id))

    plt.plot(traffic_data[:24 * 12, node_id, 1])  # 1维特征
    plt.savefig("node_{:3d}_2.png".format(node_id))

    plt.plot(traffic_data[:24 * 12, node_id, 2])  # 2维特征
    plt.savefig("node_{:3d}_3.png".format(node_id))

运行结果如下:
在这里插入图片描述
所以可得出:每个节点有三个特征,但是其他两个节点基本是平稳不变的,所以我们只取第一维特征。

三、数据处理

1、读入数据并取需要的特征

前面说了,只取第一维特征,并且为了后面方便,将节点的维度放在第一维,所以重写的get_flow()函数如下:

import csv
import torch
import numpy as np
from torch.utils.data import Dataset

def get_flow(file_name):

    flow_data = np.load(file_name)
    print([key for key in flow_data.keys()])
    print(flow_data["data"].shape)  # (16992, 307, 3),16992是时间(59*24*12),307是节点数,3表示每一维特征的维度(类似于二维的列)
    
    flow_data = flow_data['data'].transpose([1, 0, 2])[:, :, 0][:, :, np.newaxis]  # [N, T, D],transpose就是转置,让节点纬度在第0位,N为节点数,T为时间,D为节点特征
    # 对np.newaxis说一下,就是增加一个维度,这是因为一般特征比一个多,即使是一个,保持这样的习惯,便于通用的处理问题

    return flow_data

2、数据集处理:写成PyTorch所需要的数据集的类

说明:下面的所有代码都在traffic_dataset.py文件中,我为了说明把几个函数分开了,运行的时候只需要全复制到一个.py文件中即可。

(1)读入邻接矩阵

首先当然是邻接矩阵的读取,前面的.CSV文件其实就是一个直接能可视化的邻接矩阵,当然不是我们所需的那种,所以才需要处理嘛。如下所示,from和to表示的是节点,cost表示的是两个节点之间的直线距离(来表示权重),在本文中,权重都为1。

在这里插入图片描述
处理成邻接矩阵的程序如下:

import csv
import torch
import numpy as np
from torch.utils.data import Dataset

def get_adjacent_matrix(distance_file: str, num_nodes: int, id_file: str = None, graph_type="connect") -> np.array:
    """
    :param distance_file: str, path of csv file to save the distances between nodes.
    :param num_nodes: int, number of nodes in the graph
    :param id_file: str, path of txt file to save the order of the nodes.就是排序节点的绝对编号所用到的&#
  • 49
    点赞
  • 283
    收藏
    觉得还不错? 一键收藏
  • 92
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 92
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值