Graph WaveNet: Dataset Input Format and Training on a Custom Dataset

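This post describes how to use the Graph WaveNet model on traffic data: working with the METR-LA dataset, generating an adjacency matrix, converting SUMO simulation data to HDF5 format, and adjusting the arguments for training.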

Graph WaveNet code: GitHub - nnzhan/Graph-WaveNet: graph wavenet

The paper uses the METR-LA dataset. For how to download it and run the code, see my other post, "Graph WaveNet代码入门详解", on the CSDN blog.

Input data format

METR-LA

Los Angeles highway dataset, "metr-la.h5". The raw data has shape (34272, 207): 34272 time steps for 207 sensors.

Adjacency matrix

pkl file

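How do you train the GWNT model on your own dataset?

I want to use the GWNT model to predict traffic queue lengths. Using the SUMO simulator, I built a road network with 361 nodes (lanes), ran it for 20000 s, and recorded each lane's queue length every 2 s. Keeping the samples from t = 200 s onward (the time index used in the conversion script later in this post) gives 9900 data points per lane.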
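The figure of 9900 data points is consistent with sampling every 2 s from t = 200 s up to, but not including, t = 20000 s, which is the time index the conversion script later in this post builds. A quick check under that assumption (the post does not say why sampling starts at 200 s):

```python
import numpy as np

# Sample times: start at 200 s, stop before 20000 s, one sample every 2 s.
time_index = np.arange(200, 20000, 2)

num_points = len(time_index)  # (20000 - 200) / 2
first_t, last_t = int(time_index[0]), int(time_index[-1])
print(num_points, first_t, last_t)
```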

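Adjacency matrix

First, generate the adjacency matrix of the road network: set the entry to 1 for lanes that are connected or that belong to the same edge, then write the matrix out as a pkl file.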

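import pickle

import numpy as np

# lane_net: dict mapping each lane ID to its 'outflowLane' and 'sameEdgelane' lists
# laneindexMap: dict mapping each lane ID to its row/column index
# laneNum: number of lanes (361 in this network)
# All three are assumed to have been built from the SUMO network beforehand.
adj = np.zeros((laneNum, laneNum))
for lane in lane_net:
    lane_index = laneindexMap[lane]
    # Lanes that this lane flows into
    for outflow in lane_net[lane]['outflowLane']:
        outflow_index = laneindexMap[outflow]
        adj[lane_index][outflow_index] = 1
    # Other lanes on the same edge
    for sameEdge in lane_net[lane]['sameEdgelane']:
        adj_index = laneindexMap[sameEdge]
        adj[lane_index][adj_index] = 1

# Save only the matrix (no sensor ID metadata)
with open('adjacency_matrix.pkl', 'wb') as f:
    pickle.dump(adj, f)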
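To see what the loop produces, here is a minimal sketch on a toy network. The three lanes and the contents of `lane_net` and `laneindexMap` are made up for illustration; only the dictionary keys `outflowLane` and `sameEdgelane` come from the code above.

```python
import numpy as np

# Hypothetical toy network: a_0 and a_1 share an edge, and a_0 flows into b_0.
lane_net = {
    "a_0": {"outflowLane": ["b_0"], "sameEdgelane": ["a_1"]},
    "a_1": {"outflowLane": [], "sameEdgelane": ["a_0"]},
    "b_0": {"outflowLane": [], "sameEdgelane": []},
}
laneindexMap = {lane: i for i, lane in enumerate(lane_net)}

laneNum = len(lane_net)
adj = np.zeros((laneNum, laneNum))
for lane, links in lane_net.items():
    i = laneindexMap[lane]
    # Mark both downstream lanes and same-edge lanes as neighbours.
    for other in links["outflowLane"] + links["sameEdgelane"]:
        adj[i, laneindexMap[other]] = 1

print(adj)
```

The matrix is directed: the row for `b_0` stays all zero because a downstream link is only recorded on the upstream lane's row.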

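Generating the h5 data file

Existing data format: one CSV file per lane, named "lane_id.csv". The script below reads the "Time(s)" and "QueueLength(m)" columns from each file.

Generating the h5 file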

import pandas as pd
import numpy as np
import os

# Directory where CSV files are stored
csv_dir = 'output_data_file'  # Change this to the path where your CSV files are stored

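# List all CSV files in the directory.
# Note: the resulting column order must match the lane order (laneindexMap)
# used for the adjacency matrix; os.listdir() does not guarantee any order.
csv_files = [f for f in os.listdir(csv_dir) if f.endswith('.csv')]

# Sample times: 200, 202, ..., 19998 (step 2 s, end exclusive), 9900 points in total
time_index = np.arange(200, 20000, 2)
full_df = pd.DataFrame(index=time_index)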

# Process each CSV file
for file_name in csv_files:
    lane_name = file_name.replace('.csv', '')  # Extract the lane name from the file name
    file_path = os.path.join(csv_dir, file_name)
    df = pd.read_csv(file_path)

    # Set the index to 'Time(s)'
    df.set_index('Time(s)', inplace=True)
    # Reindex the DataFrame to the full time range, filling any missing data with NaN
    df = df.reindex(time_index, fill_value=np.nan)  # Use NaN for missing data

    # Add the data to the full DataFrame
    full_df[lane_name] = df['QueueLength(m)']

# Interpolate to fill in NaN values if necessary
# full_df.interpolate(method='linear', inplace=True)

# Now save the DataFrame into an HDF5 file in the required format
h5_file_path = 'data.h5'  # Change this to your desired output path
full_df.to_hdf(h5_file_path, key='df', mode='w')

print(f"Saved to {h5_file_path}")
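The `reindex` call is what aligns every lane to the same timestamps. Here is a small sketch of its effect, using an invented lane table with one missing sample. The column names match the script above; the values are made up.

```python
import numpy as np
import pandas as pd

# Invented readings for one lane; the sample at t=204 is missing.
df = pd.DataFrame({"Time(s)": [200, 202, 206], "QueueLength(m)": [0.0, 7.5, 15.0]})
df = df.set_index("Time(s)")

time_index = np.arange(200, 208, 2)  # 200, 202, 204, 206
aligned = df.reindex(time_index)     # rows absent from the CSV become NaN

n_missing = int(aligned["QueueLength(m)"].isna().sum())
# Optional: fill gaps linearly, as the commented-out line in the script would.
filled = aligned["QueueLength(m)"].interpolate(method="linear")
print(n_missing, filled.tolist())
```

Whether to interpolate or keep the NaN values depends on how missing samples should be treated during training; the script above leaves interpolation commented out.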

Run generate_training_data.py

!python /content/drive/MyDrive/Graph-WaveNet-master/generate_training_data.py --output_dir='/content/drive/MyDrive/data' --traffic_df_filename='/content/drive/MyDrive/data.h5'

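Printing df inside the generate_train_val_test function shows that the data is loaded as a DataFrame, and the x shape again follows the [batch_size, seq_length, num_nodes, channels] layout (the first dimension is the number of samples).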
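For intuition about where that shape comes from, the following is a simplified sketch of sliding-window sample generation, not the repository's code. It assumes 12 input steps and 12 prediction steps (which I believe are the script's defaults) and a single feature channel; `make_windows` is a hypothetical helper name.

```python
import numpy as np

def make_windows(data, seq_len_x=12, seq_len_y=12):
    """Slice a (T, N) array into inputs x and targets y.

    Returns x with shape (S, seq_len_x, N, 1) and y with shape
    (S, seq_len_y, N, 1), where S = T - seq_len_x - seq_len_y + 1.
    """
    data = data[..., np.newaxis]  # (T, N, 1): one channel
    num_samples = data.shape[0] - seq_len_x - seq_len_y + 1
    x = np.stack([data[t : t + seq_len_x] for t in range(num_samples)])
    y = np.stack([data[t + seq_len_x : t + seq_len_x + seq_len_y]
                  for t in range(num_samples)])
    return x, y

# Small stand-in for the real (9900, 361) table to keep the demo light.
toy = np.arange(100 * 4, dtype=np.float32).reshape(100, 4)
x, y = make_windows(toy)
print(x.shape, y.shape)
```

With the 9900 time points and 361 lanes from this post, the same formula gives 9900 - 23 = 9877 samples before the train/validation/test split.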

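Other changes

The custom pkl file contains only the adjacency matrix, so every place that loads the pkl file and unpacks its contents needs to change. (The pkl file shipped with the dataset also stores sensor_ids and sensor_id_to_ind, but these two variables are not used later in training.)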
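If you would like the same code to accept both the provided pickle and your own, one option is to normalise whatever is loaded into a bare matrix. This is a sketch, not code from the repository: `extract_adj` is a hypothetical helper, and it assumes the provided file unpickles to a 3-element sequence (sensor_ids, sensor_id_to_ind, adj_mx) whose second element is a dict.

```python
import pickle
import numpy as np

def extract_adj(obj):
    """Return the adjacency matrix from either pickle layout."""
    # Bundled layout: (sensor_ids, sensor_id_to_ind, adj_mx)
    if isinstance(obj, (list, tuple)) and len(obj) == 3 and isinstance(obj[1], dict):
        obj = obj[2]
    adj = np.asarray(obj, dtype=np.float32)
    if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {adj.shape}")
    return adj

# Round-trip both layouts through pickle to show they give the same result.
adj = np.eye(3)
bundle = (["a", "b", "c"], {"a": 0, "b": 1, "c": 2}, adj)
bare = extract_adj(pickle.loads(pickle.dumps(adj)))
bundled = extract_adj(pickle.loads(pickle.dumps(bundle)))
print(bare.shape, np.array_equal(bare, bundled))
```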

util.py

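def load_adj(pkl_filename, adjtype):
    # The custom pkl file holds only the adjacency matrix, so load it directly
    # instead of unpacking (sensor_ids, sensor_id_to_ind, adj_mx).
    adj_mx = load_pickle(pkl_filename)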
    if adjtype == "scalap":
        adj = [calculate_scaled_laplacian(adj_mx)]
    elif adjtype == "normlap":
        adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32).todense()]
    elif adjtype == "symnadj":
        adj = [sym_adj(adj_mx)]
    elif adjtype == "transition":
        adj = [asym_adj(adj_mx)]
    elif adjtype == "doubletransition":
        adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))]
    elif adjtype == "identity":
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    else:
        error = 0
        assert error, "adj type not defined"
    return adj
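For reference, the "transition" and "doubletransition" options rely on a row-normalised adjacency matrix. The sketch below is a plain-NumPy approximation of what I understand asym_adj to compute (D^-1 A, with all-zero rows left as zeros). `row_normalize` is a hypothetical name, and the repository's own implementation may differ in details such as sparse-matrix handling.

```python
import numpy as np

def row_normalize(adj):
    """Divide each row by its sum; rows that sum to zero stay zero."""
    adj = np.asarray(adj, dtype=np.float32)
    rowsum = adj.sum(axis=1)
    d_inv = np.zeros_like(rowsum)
    np.divide(1.0, rowsum, out=d_inv, where=rowsum != 0)
    return d_inv[:, np.newaxis] * adj

adj = np.array([[0, 1, 1],
                [1, 0, 0],
                [0, 0, 0]])
forward = row_normalize(adj)     # normalised along the edge direction
backward = row_normalize(adj.T)  # normalised against the edge direction
print(forward)
print(backward)
```

Rows with no links, like the last row of `forward` here, are kept as zeros rather than triggering a division by zero.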

Change in the main function of train.py

adj_mx = util.load_adj(args.adjdata,args.adjtype)

Training

!python '/content/drive/MyDrive/Graph-WaveNet-master (1)/train.py' --data '/content/drive/MyDrive/data' --adjdata '/content/drive/MyDrive/Graph-WaveNet-master/adjacency_matrix.pkl' --device 'cuda:0' --in_dim '1' --num_nodes '361' --save '/content/drive/MyDrive/garage/' --epochs '200'

When training, adjust the input arguments to match your data (here, --in_dim 1 and --num_nodes 361).
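Before launching a long run, it can help to confirm that --num_nodes and --in_dim agree with the generated data and the adjacency matrix. This is a minimal sketch, assuming inputs are laid out as [samples, seq_length, num_nodes, channels] as described earlier and that the model consumes every channel of x. `check_shapes` is a hypothetical helper, and the arrays below are placeholders rather than the real files.

```python
import numpy as np

def check_shapes(x, adj, num_nodes, in_dim):
    """Raise ValueError if data, adjacency matrix and CLI values disagree."""
    if x.ndim != 4:
        raise ValueError(f"x should be 4-D, got shape {x.shape}")
    if x.shape[2] != num_nodes:
        raise ValueError(f"x has {x.shape[2]} nodes, --num_nodes is {num_nodes}")
    if x.shape[3] != in_dim:
        raise ValueError(f"x has {x.shape[3]} channels, --in_dim is {in_dim}")
    if adj.shape != (num_nodes, num_nodes):
        raise ValueError(f"adjacency shape {adj.shape} does not match --num_nodes")
    return True

# Placeholder arrays with the sizes used in this post (361 lanes, 1 channel).
x = np.zeros((8, 12, 361, 1), dtype=np.float32)
adj = np.zeros((361, 361), dtype=np.float32)
ok = check_shapes(x, adj, num_nodes=361, in_dim=1)
print(ok)
```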
