STGCN_IJCAI-18-master代码解读(八):data_utils.py

解读data_utils.py

from utils.math_utils import z_score

import numpy as np
import pandas as pd

Dataset()类

class Dataset(object):
    def __init__(self, data, stats):
        self.__data = data
        self.mean = stats['mean']
        self.std = stats['std']

    def get_data(self, type):
        return self.__data[type]

    def get_stats(self):
        return {'mean': self.mean, 'std': self.std}

    def get_len(self, type):
        return len(self.__data[type])

    def z_inverse(self, type):
        return self.__data[type] * self.std + self.mean

这个代码定义了一个名为Dataset的Python类,该类用于存储和操作数据集。这个类特别适用于处理归一化后的数据,因为它还包括数据的统计信息(平均值和标准差)。下面是对每一部分的详细解释:

__init__(self, data, stats)

  • __init__是构造函数,用于初始化Dataset对象。
  • self.__data = data:这里,__data是一个私有变量,用于存储传入的data。这通常是一个字典,其中可能包含“训练”、“验证”和“测试”等多种类型的数据。
  • self.mean = stats['mean']self.std = stats['std']:这两行代码存储了数据的平均值和标准差,这些统计信息通常用于数据归一化。

get_data(self, type)

  • 此方法用于获取指定类型(如“train”、“test”、“val”等)的数据。
  • 它返回与指定type关联的数据。

get_stats(self)

  • 此方法返回数据的统计信息,即平均值和标准差。
  • 返回的是一个字典,其中包含两个键:'mean''std'

get_len(self, type)

  • 此方法用于获取指定类型(如“train”、“test”、“val”等)的数据长度。
  • 它返回与指定type关联的数据长度。

z_inverse(self, type)

  • 这个方法用于进行z分数反转换,即从归一化数据返回到原始数据。
  • 该方法使用存储在self.meanself.std中的平均值和标准差,以将归一化的数据转换回其原始范围。

通过这个Dataset类,您可以轻松地管理和操作您的数据,包括获取数据的不同部分(如训练、测试等)、获取数据的统计信息和长度,以及在需要时进行反归一化。

seq_gen()函数

def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
    '''
    Generate data in the form of standard sequence unit.
    :param len_seq: int, the length of target date sequence.
    :param data_seq: np.ndarray, source data / time-series.
    :param offset:  int, the starting index of different dataset type.
    :param n_frame: int, the number of frame within a standard sequence unit,
                         which contains n_his = 12 and n_pred = 9 (3 /15 min, 6 /30 min & 9 /45 min).
    :param n_route: int, the number of routes in the graph.
    :param day_slot: int, the number of time slots per day, controlled by the time window (5 min as default).
    :param C_0: int, the size of input channel.
    :return: np.ndarray, [len_seq, n_frame, n_route, C_0].
    '''
    n_slot = day_slot - n_frame + 1

    tmp_seq = np.zeros((len_seq * n_slot, n_frame, n_route, C_0))
    for i in range(len_seq):
        for j in range(n_slot):
            sta = (i + offset) * day_slot + j
            end = sta + n_frame
            tmp_seq[i * n_slot + j, :, :, :] = np.reshape(data_seq[sta:end, :], [n_frame, n_route, C_0])
    return tmp_seq

这段代码定义了一个函数seq_gen,该函数用于生成标准的序列数据单元,适用于时间序列数据或图结构数据中。这种序列数据单元主要用于训练或评估机器学习模型(特别是时间序列模型,如LSTM或GRU,或者图神经网络模型)。

下面是参数和代码的详细解释:

参数:

  • len_seq:目标日期序列的长度。
  • data_seq:源数据或时间序列,通常是一个Numpy数组。
  • offset:不同数据集类型的起始索引。
  • n_frame:一个标准序列单元中的帧数(时间步)。这通常包括历史数据(n_his)和预测数据(n_pred)。
  • n_route:图中路由(节点)的数量。
  • day_slot:每天的时间槽数量,由时间窗口(默认为5分钟)控制。
  • C_0:输入通道的大小,默认为1。

代码逻辑:

  1. n_slot = day_slot - n_frame + 1:计算每天可用于生成标准序列单元的时间槽数量。

  2. tmp_seq = np.zeros((len_seq * n_slot, n_frame, n_route, C_0)):初始化一个全零的Numpy数组,用于存储生成的序列数据。

  3. 双重循环:

    • 外层循环:遍历目标日期序列。
    • 内层循环:遍历每天内的每个可用时间槽。

    在这两个循环中,对于每个时间槽和每个目标日期,函数都会:

    1. 计算源数据中相应序列的起始(sta)和结束(end)索引。
    2. 从源数据中提取相应的序列,并将其重塑为[n_frame, n_route, C_0]的形状。
    3. 将这个重塑后的数组存储在tmp_seq中的相应位置。
  4. 返回生成的序列数据tmp_seq

通过这个函数,您可以方便地从原始的时间序列或图结构数据中生成适用于模型训练或评估的标准序列单元。这对于进行时间序列预测或图结构数据分析非常有用。

data_gen()函数

def data_gen(file_path, data_config, n_route, n_frame=21, day_slot=288):
    '''
    Source file load and dataset generation.
    :param file_path: str, the file path of data source.
    :param data_config: tuple, the configs of dataset in train, validation, test.
    :param n_route: int, the number of routes in the graph.
    :param n_frame: int, the number of frame within a standard sequence unit,
                         which contains n_his = 12 and n_pred = 9 (3 /15 min, 6 /30 min & 9 /45 min).
    :param day_slot: int, the number of time slots per day, controlled by the time window (5 min as default).
    :return: dict, dataset that contains training, validation and test with stats.
    '''
    n_train, n_val, n_test = data_config
    # generate training, validation and test data
    try:
        data_seq = pd.read_csv(file_path, header=None).values
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')

    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route, day_slot)

    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}

    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    return dataset

这个函数data_gen主要用于从给定的CSV文件路径中加载数据,并根据特定的配置生成用于训练、验证和测试的数据集。

参数:

  • file_path: CSV文件的路径,该文件包含源数据。
  • data_config: 一个包含三个整数的元组(n_train, n_val, n_test),分别表示训练、验证和测试数据的天数。
  • n_route: 图中路由(或节点)的数量。
  • n_frame: 在一个标准序列单元中的帧(时间步)数量。
  • day_slot: 每天的时间槽数量,通常由时间窗口(例如,5分钟)控制。

代码逻辑:

  1. 加载数据:使用pd.read_csv函数从CSV文件中读取数据,并将其转换为NumPy数组。

  2. 生成训练、验证和测试序列

    • 使用前面定义的seq_gen函数来生成训练、验证和测试数据。
    • 对于训练数据,从源数据的开始索引(offset=0)生成。
    • 对于验证数据,从源数据的n_train索引处开始生成。
    • 对于测试数据,从源数据的n_train + n_val索引处开始生成。
  3. 计算统计数据

    • 计算训练数据集的平均值和标准差,存储在x_stats字典中。
  4. 标准化数据

    • 使用z分数方法(z-score)对训练、验证和测试数据进行标准化。
    • 假设z_score函数用于进行这种标准化(虽然代码没有给出这个函数,但从上下文可以推断出这一点)。
  5. 组织数据和统计信息

    • 将标准化后的训练、验证和测试数据存储在x_data字典中。
    • 使用前面定义的Dataset类创建一个数据集对象,并将x_datax_stats作为参数传递。
  6. 返回数据集对象:最后,函数返回这个Dataset对象,其中包含标准化后的训练、验证和测试数据,以及关于这些数据的统计信息。

通过这个data_gen函数,你可以方便地从一个CSV文件中准备出适用于机器学习模型(特别是时间序列模型或图神经网络)的训练、验证和测试数据。这个函数也处理了数据的标准化,这是许多机器学习算法所需要的。

gen_batch()函数

def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
    '''
    Data iterator in batch.
    :param inputs: np.ndarray, [len_seq, n_frame, n_route, C_0], standard sequence units.
    :param batch_size: int, the size of batch.
    :param dynamic_batch: bool, whether changes the batch size in the last batch if its length is less than the default.
    :param shuffle: bool, whether shuffle the batches.
    '''
    len_inputs = len(inputs)

    if shuffle:
        idx = np.arange(len_inputs)
        np.random.shuffle(idx)

    for start_idx in range(0, len_inputs, batch_size):
        end_idx = start_idx + batch_size
        if end_idx > len_inputs:
            if dynamic_batch:
                end_idx = len_inputs
            else:
                break
        if shuffle:
            slide = idx[start_idx:end_idx]
        else:
            slide = slice(start_idx, end_idx)

        yield inputs[slide]

这个函数gen_batch用于生成批量的数据,它是一个迭代器。主要应用在机器学习和深度学习训练中,用于将大量数据分成小批量(batch)进行逐批训练。

参数:

  • inputs: 一个NumPy数组,形状是 [len_seq, n_frame, n_route, C_0],代表标准的序列单元。
  • batch_size: 批量的大小,即每次迭代返回多少条数据。
  • dynamic_batch: 一个布尔值,表示是否在最后一个批次改变批大小,如果其长度小于默认的批大小。
  • shuffle: 一个布尔值,表示是否需要随机打乱批次。

代码逻辑:

  1. 获取输入长度: 使用len函数获取inputs数组的长度(即序列的总数)。

  2. 是否打乱数据:

    • 如果shuffle=True,那么会生成一个从0到len_inputs-1的整数数组,并随机打乱它。
  3. 生成批次:

    • 使用for循环从0开始,步长为batch_size,迭代到len_inputs
    • start_idx是每个批次的起始索引,end_idx是终止索引。
    • 如果end_idx大于len_inputs(超出数组长度),则根据dynamic_batch决定是否调整最后一个批次的大小。
  4. 数据切片与返回:

    • 使用yield关键字返回每个批次的数据。
    • 如果shuffle=True,则根据打乱后的索引返回数据。
    • 否则,直接使用切片从start_idxend_idx提取数据。

使用场景:

这种批次生成器非常适用于大规模数据集,特别是当你不能一次性将所有数据加载到内存中时。此外,通过设定shuffle=True,模型在训练时可以接触到随机化的数据,这通常有助于提高模型的泛化能力。

总体来说,这个函数提供了一种灵活和高效的方式来迭代大规模数据集,并且还支持动态批处理和数据洗牌。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值