【ASTGCN】Model Debugging Study Notes -- Data Generation Explained (In Detail)

Generating time series with a sliding window

Schematic:

Take the PEMS04 dataset as an example.

  • The dataset's shape is (16992, 307, 3): 16992 is the length of the time series, 307 is the number of detectors, i.e. the number of graph vertices, and 3 is the number of features, namely flow, speed and average occupancy.
  • A sliding window now generates new time series. Suppose the window size (how many time steps are taken per slide) is 4 and the stride (how many steps the window moves each time) is 1. As Figure 1 shows, each slide takes 4 steps of data (the total length, 16992, is L in Figure 1) and the window advances 1 step between slides; the extracted windows are then stacked into the new data, as in Figure 2.
    [image]
    Figure 1
    [image]
    Figure 2
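The sliding-window idea above can be sketched with a toy 1-D series (a minimal illustration, not the ASTGCN code; the array and names here are made up, and the node/feature axes are omitted):

```python
import numpy as np

# Toy series standing in for the 16992-step time axis.
series = np.arange(10)  # [0, 1, ..., 9]
window_size, stride = 4, 1

# Slide a window of length 4 with stride 1 and stack the pieces, as in Figure 2.
windows = np.stack([series[i:i + window_size]
                    for i in range(0, len(series) - window_size + 1, stride)])
print(windows.shape)   # (7, 4): 10 - 4 + 1 = 7 windows of length 4
print(windows[0])      # [0 1 2 3]
print(windows[1])      # [1 2 3 4]
```

With the real data the same loop slides over the first axis of the (16992, 307, 3) array, so each window has shape (4, 307, 3).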

Function operations

Function call hierarchy:

def read_and_generate_dataset(graph_signal_matrix_filename,
                              num_of_weeks, num_of_days,
                              num_of_hours, num_for_predict,
                              points_per_hour=12, save=False):
   def get_sample_indices(data_sequence, num_of_weeks, num_of_days, num_of_hours,
                       label_start_idx, num_for_predict, points_per_hour=12):
             def search_data(sequence_length, num_of_depend, label_start_idx,
                 num_for_predict, units, points_per_hour):

read_and_generate_dataset calls get_sample_indices, which in turn calls search_data.

The search_data function

  • Purpose: compute the start and end index of each window produced by sliding.
  • Walkthrough:
def search_data(sequence_length, num_of_depend, label_start_idx,
                num_for_predict, units, points_per_hour):
    #### Parameter notes
    # sequence_length: in the source this receives data_sequence.shape[0] passed by
    #   get_sample_indices, i.e. the first dimension of the raw data (16992, 307, 3),
    #   so 16992.
    # num_of_depend: how many recent / daily / weekly periods to use; in the source
    #   it defaults to num_of_hours = 1.
    # label_start_idx: receives label_start_idx from get_sample_indices, which in
    #   turn receives idx from read_and_generate_dataset, where idx ranges over
    #   range(data_seq.shape[0]), i.e. 0~16991. So label_start_idx here is 0~16991;
    #   search_data is called inside that for loop.
    # num_for_predict: number of time steps to predict; 12 in the source, i.e. one hour.
    # units: get_sample_indices passes one of three values, 7 * 24, 24, or 1 -- the
    #   sliding-window stride mentioned above, corresponding to the weekly, daily and
    #   recent (hourly) periods of the paper; 1 means one hour, 24 means 24 hours
    #   (one day), and 7 * 24 means one week.
    # points_per_hour: time steps per hour, 12.

    # If points_per_hour is negative, raise a ValueError saying it should be positive.
    if points_per_hour < 0:
        raise ValueError("points_per_hour should be greater than 0!")

    # If the prediction target's start index plus the prediction horizon exceeds the
    # length of the history, return None: no valid index range can be generated.
    # E.g. when the loop reaches idx (label_start_idx) = 16981, 16981 + 12 > 16992.
    if label_start_idx + num_for_predict > sequence_length:
        return None

    # Empty list to collect the generated index ranges.
    x_idx = []

    # Iterate over the number of dependent segments. Each iteration computes the
    # start index start_idx and end index end_idx of the current dependent segment.
    for i in range(1, num_of_depend + 1):  # num_of_hours is 1 in the source, so one iteration
        # Start index of the current dependent segment
        start_idx = label_start_idx - points_per_hour * units * i  # idx - 12*1*1
        # End index of the current dependent segment
        end_idx = start_idx + num_for_predict  # start_idx + 12
        # A start index >= 0 means the segment lies inside the history and is valid,
        # so it is appended to the result list.
        if start_idx >= 0:
            x_idx.append((start_idx, end_idx))
        else:
            return None
    # If the number of generated ranges differs from the expected number of
    # dependent segments, something went wrong; return None.
    if len(x_idx) != num_of_depend:
        return None
    # Reverse the list of index ranges and return it.
    return x_idx[::-1]

Worked examples:

  • With num_of_depend = 1 in for i in range(1, num_of_depend + 1): (the source's default), one hour of recent history is used for prediction.
    The loop becomes for i in range(1, 2): and runs exactly once.

    • Suppose the loop in read_and_generate_dataset has reached idx = 13, so label_start_idx = 13.
      start_idx = label_start_idx - points_per_hour * units * i = 13 - 12*1*1 = 1; end_idx = start_idx + num_for_predict = 1 + 12 = 13.
      start_idx >= 0, so x_idx.append((start_idx, end_idx)) gives x_idx = [(1, 13)] and len(x_idx) == 1.
      len(x_idx) != num_of_depend is False (1 == 1), so the result is returned.

    • Suppose the loop has reached idx = 14, so label_start_idx = 14.
      start_idx = 14 - 12*1*1 = 2; end_idx = 2 + 12 = 14.
      start_idx >= 0, so x_idx = [(2, 14)] and len(x_idx) == 1.
      len(x_idx) != num_of_depend is False (1 == 1), so the result is returned.

  • With num_of_depend = 2 in for i in range(1, num_of_depend + 1):, two hours of recent history are used for prediction.
    The loop becomes for i in range(1, 3): and runs twice.

    • Suppose the loop in read_and_generate_dataset has reached idx = 12 and the inner loop is at i = 1.
      label_start_idx = 12, so start_idx = 12 - 12*1*1 = 0 and end_idx = 0 + 12 = 12.
      start_idx >= 0, so x_idx = [(0, 12)] and len(x_idx) == 1.

    • Still at idx = 12, the inner loop moves to i = 2.
      start_idx = 12 - 12*1*2 = -12 and end_idx = -12 + 12 = 0.
      start_idx = -12 < 0, so the index pair is not appended,
      and this call to search_data (after its two loop iterations) returns None.

    • Suppose the loop has reached idx = 24 and the inner loop is at i = 1.
      label_start_idx = 24, so start_idx = 24 - 12*1*1 = 12 and end_idx = 12 + 12 = 24.
      start_idx >= 0, so x_idx = [(12, 24)].

    • Still at idx = 24, the inner loop moves to i = 2.
      start_idx = 24 - 12*1*2 = 0 and end_idx = 0 + 12 = 12.
      start_idx >= 0, so x_idx = [(12, 24), (0, 12)].
      After the two iterations, return x_idx[::-1] reverses the list,
      so this call returns [(0, 12), (12, 24)].
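The worked examples above can be checked by running search_data standalone (the function body below is copied from the source, stripped of comments, so the snippet is self-contained):

```python
def search_data(sequence_length, num_of_depend, label_start_idx,
                num_for_predict, units, points_per_hour=12):
    # Copied from the ASTGCN data-preparation source.
    if points_per_hour < 0:
        raise ValueError("points_per_hour should be greater than 0!")
    if label_start_idx + num_for_predict > sequence_length:
        return None
    x_idx = []
    for i in range(1, num_of_depend + 1):
        start_idx = label_start_idx - points_per_hour * units * i
        end_idx = start_idx + num_for_predict
        if start_idx >= 0:
            x_idx.append((start_idx, end_idx))
        else:
            return None
    if len(x_idx) != num_of_depend:
        return None
    return x_idx[::-1]

print(search_data(16992, 1, 13, 12, 1, 12))   # [(1, 13)]
print(search_data(16992, 2, 12, 12, 1, 12))   # None: start_idx goes negative at i = 2
print(search_data(16992, 2, 24, 12, 1, 12))   # [(0, 12), (12, 24)]
```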

Test code:

import numpy as np

if __name__ == '__main__':
    # graph_signal_matrix_filename is the path to the PEMS04.npz file
    data_seq = np.load(graph_signal_matrix_filename)['data']
    for idx in range(data_seq.shape[0]):
        hour_indices = search_data(data_seq.shape[0], 1, idx, 12, 1, 12)  # one hour of indices
        print("hour_indice:", hour_indices)

Output:

hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: [(0, 12)]
hour_indice: [(1, 13)]
hour_indice: [(2, 14)]
hour_indice: [(3, 15)]
hour_indice: [(4, 16)]
hour_indice: [(5, 17)]
hour_indice: [(6, 18)]
hour_indice: [(7, 19)]
hour_indice: [(8, 20)]
hour_indice: [(9, 21)]
hour_indice: [(10, 22)]
hour_indice: [(11, 23)]
hour_indice: [(12, 24)]
...
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
if __name__ == '__main__':
    data_seq = np.load(graph_signal_matrix_filename)['data']
    for idx in range(data_seq.shape[0]):
        hour_indices = search_data(data_seq.shape[0], 2, idx, 12, 1, 12)  # two hours of indices
        print("hour_indice:", hour_indices)

Output:

hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: [(0, 12), (12, 24)]
hour_indice: [(1, 13), (13, 25)]
hour_indice: [(2, 14), (14, 26)]
hour_indice: [(3, 15), (15, 27)]
...
hour_indice: None
hour_indice: None
hour_indice: None
hour_indice: None

The get_sample_indices function

  • Purpose: obtain sample data for the recent (hourly), daily and weekly periods.
  • Walkthrough:
def get_sample_indices(data_sequence, num_of_weeks, num_of_days, num_of_hours,
                       label_start_idx, num_for_predict, points_per_hour=12):

    #### Parameter notes
    # data_sequence: in the source this receives data_seq passed by
    #   read_and_generate_dataset, i.e. the raw data of shape (16992, 307, 3).
    # num_of_weeks: 0
    # num_of_days: 0
    # num_of_hours: 1
    # label_start_idx: receives idx from read_and_generate_dataset, where idx ranges
    #   over range(data_seq.shape[0]), i.e. 0~16991; get_sample_indices is called
    #   inside that for loop.
    # num_for_predict: number of time steps to predict; 12 in the source, i.e. one hour.
    # points_per_hour: time steps per hour, 12.

    week_sample, day_sample, hour_sample = None, None, None

    # If the index runs off the end of the data, return immediately; e.g. when the
    # loop reaches idx (label_start_idx) = 16981, 16981 + 12 > 16992.
    if label_start_idx + num_for_predict > data_sequence.shape[0]:
        return week_sample, day_sample, hour_sample, None

    # num_of_weeks, num_of_days and num_of_hours select which period components are
    # built; in this walkthrough's configuration only num_of_hours is greater than 0.
    if num_of_weeks > 0:
        week_indices = search_data(data_sequence.shape[0], num_of_weeks,
                                   label_start_idx, num_for_predict,
                                   7 * 24, points_per_hour)
        if not week_indices:
            return None, None, None, None

        week_sample = np.concatenate([data_sequence[i: j]
                                      for i, j in week_indices], axis=0)

    if num_of_days > 0:
        day_indices = search_data(data_sequence.shape[0], num_of_days,
                                  label_start_idx, num_for_predict,
                                  24, points_per_hour)
        if not day_indices:
            return None, None, None, None

        day_sample = np.concatenate([data_sequence[i: j]
                                     for i, j in day_indices], axis=0)

    if num_of_hours > 0:
        # Generate the hourly slice indices; search_data returns index lists such as
        # [(0, 12)], [(1, 13)], ..., or [(0, 12), (12, 24)] when num_of_hours = 2.
        hour_indices = search_data(data_sequence.shape[0], num_of_hours,
                                   label_start_idx, num_for_predict,
                                   1, points_per_hour)
        if not hour_indices:
            return None, None, None, None
        # Extract the data from the raw sequence by those indices.
        hour_sample = np.concatenate([data_sequence[i: j]
                                      for i, j in hour_indices], axis=0)
    # Build the label.
    target = data_sequence[label_start_idx: label_start_idx + num_for_predict]
    return week_sample, day_sample, hour_sample, target

Worked example

import numpy as np

num_for_predict = 12  # 12 in the source configuration

if __name__ == '__main__':
    data_seq = np.load(graph_signal_matrix_filename)['data']
    sample = []
    targetlist = []
    for idx in range(data_seq.shape[0]):
        hour_indice = search_data(data_seq.shape[0], 1, idx, 12, 1, 12)

        if not hour_indice:
            continue
        # Take the indexed slices from the raw data and stack them as in Figure 2.
        hour_sample = np.concatenate([data_seq[i: j]
                              for i, j in hour_indice], axis=0)
        sample.append(hour_sample)
        # The num_for_predict steps that follow hour_sample become the label.
        target = data_seq[idx: idx + num_for_predict]
        targetlist.append(target)
        print("idx:", idx)
        print("hour_sample.shape:", hour_sample.shape)
    print("len(sample):", len(sample))
    print("sample[0].shape):", sample[0].shape)
    print("sample[0][0].shape:", sample[0][0].shape)
    print("len(targetlist):", len(targetlist))
    print("targetlist[0].shape:", targetlist[0].shape)
    print("targetlist[0][0].shape:", targetlist[0][0].shape)

Partial output

idx: 12
hour_sample.shape: (12, 307, 3)
idx: 13
hour_sample.shape: (12, 307, 3)
idx: 14
hour_sample.shape: (12, 307, 3)
idx: 15
hour_sample.shape: (12, 307, 3)
idx: 16
hour_sample.shape: (12, 307, 3)
idx: 17
hour_sample.shape: (12, 307, 3)
idx: 18
hour_sample.shape: (12, 307, 3)
...
len(sample): 16969
sample[0].shape): (12, 307, 3)
sample[0][0].shape: (307, 3)
len(targetlist): 16969
targetlist[0].shape: (12, 307, 3)
targetlist[0][0].shape: (307, 3)

Here, the entry

idx: 12
hour_sample.shape: (12, 307, 3)

is exactly the data extracted from the raw dataset using the index hour_indice: [(0, 12)] produced by search_data.
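That correspondence can be checked with a toy stand-in for the raw data (random values instead of the real (16992, 307, 3) array; names are illustrative):

```python
import numpy as np

data_seq = np.random.rand(30, 4, 3)          # toy stand-in for (16992, 307, 3)
hour_indice = [(0, 12)]                      # what search_data returns at idx = 12
hour_sample = np.concatenate([data_seq[i:j] for i, j in hour_indice], axis=0)
print(hour_sample.shape)                     # (12, 4, 3)
# Concatenating the slices for [(0, 12)] is just the direct slice data_seq[0:12].
print(np.array_equal(hour_sample, data_seq[0:12]))  # True
```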

The read_and_generate_dataset function

  • Purpose: calls search_data and get_sample_indices to obtain samples and labels by recent, daily and weekly period, and puts the samples, labels and time steps into the all_samples list.
  • Walkthrough:
def read_and_generate_dataset(graph_signal_matrix_filename,
                              num_of_weeks, num_of_days,
                              num_of_hours, num_for_predict,
                              points_per_hour=12, save=False):

    all_samples = []
    for idx in range(data_seq.shape[0]):

        sample = get_sample_indices(data_seq, num_of_weeks, num_of_days,
                                    num_of_hours, idx, num_for_predict,
                                    points_per_hour)
        if ((sample[0] is None) and (sample[1] is None) and (sample[2] is None)):
            continue
        week_sample, day_sample, hour_sample, target = sample
        sample = []

        # N is the number of sensors, F the number of features, T the time span
        if num_of_weeks > 0:
            week_sample = np.expand_dims(week_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(week_sample)

        if num_of_days > 0:
            day_sample = np.expand_dims(day_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(day_sample)
        # reorder the dimensions of hour_sample (sample_i)
        if num_of_hours > 0:
            hour_sample = np.expand_dims(hour_sample, axis=0).transpose((0, 2, 3, 1))  # (1,N,F,T)
            sample.append(hour_sample)

        target = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :]  # (1,N,T)
        sample.append(target)

        time_sample = np.expand_dims(np.array([idx]), axis=0)  # (1,1)
        sample.append(time_sample)

        all_samples.append(
            sample)  # sample: [(week_sample),(day_sample),(hour_sample),target,time_sample] = [(1,N,F,Tw),(1,N,F,Td),(1,N,F,Th),(1,N,Tpre),(1,1)]
    # 60% for training, 20% for validation, 20% for testing
    split_line1 = int(len(all_samples) * 0.6)
    split_line2 = int(len(all_samples) * 0.8)

    training_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[:split_line1])]  # [(B,N,F,Tw),(B,N,F,Td),(B,N,F,Th),(B,N,Tpre),(B,1)]
    validation_set = [np.concatenate(i, axis=0)
                      for i in zip(*all_samples[split_line1: split_line2])]
    testing_set = [np.concatenate(i, axis=0)
                   for i in zip(*all_samples[split_line2:])]
    train_x = np.concatenate(training_set[:-2], axis=-1)  # (B,N,F,T')
    val_x = np.concatenate(validation_set[:-2], axis=-1)
    test_x = np.concatenate(testing_set[:-2], axis=-1)

    train_target = training_set[-2]  # (B,N,T)
    val_target = validation_set[-2]
    test_target = testing_set[-2]

    train_timestamp = training_set[-1]  # (B,1)
    val_timestamp = validation_set[-1]
    test_timestamp = testing_set[-1]

    (stats, train_x_norm, val_x_norm, test_x_norm) = normalization(train_x, val_x, test_x)

    all_data = {
        'train': {
            'x': train_x_norm,
            'target': train_target,
            'timestamp': train_timestamp,
        },
        'val': {
            'x': val_x_norm,
            'target': val_target,
            'timestamp': val_timestamp,
        },
        'test': {
            'x': test_x_norm,
            'target': test_target,
            'timestamp': test_timestamp,
        },
        'stats': {
            '_mean': stats['_mean'],
            '_std': stats['_std'],
        }
    }
    print('train x:', all_data['train']['x'].shape)
    print('train target:', all_data['train']['target'].shape)
    print('train timestamp:', all_data['train']['timestamp'].shape)
    print()
    print('val x:', all_data['val']['x'].shape)
    print('val target:', all_data['val']['target'].shape)
    print('val timestamp:', all_data['val']['timestamp'].shape)
    print()
    print('test x:', all_data['test']['x'].shape)
    print('test target:', all_data['test']['target'].shape)
    print('test timestamp:', all_data['test']['timestamp'].shape)
    print()
    print('train data _mean :', stats['_mean'].shape, stats['_mean'])
    print('train data _std :', stats['_std'].shape, stats['_std'])

    if save:
        file = os.path.basename(graph_signal_matrix_filename).split('.')[0]
        dirpath = os.path.dirname(graph_signal_matrix_filename)
        filename = os.path.join(dirpath, file + '_r' + str(num_of_hours) + '_d' + str(num_of_days) + '_w' + str(
            num_of_weeks)) + '_astcgn'
        print('save file:', filename)
        np.savez_compressed(filename,
                            train_x=all_data['train']['x'], train_target=all_data['train']['target'],
                            train_timestamp=all_data['train']['timestamp'],
                            val_x=all_data['val']['x'], val_target=all_data['val']['target'],
                            val_timestamp=all_data['val']['timestamp'],
                            test_x=all_data['test']['x'], test_target=all_data['test']['target'],
                            test_timestamp=all_data['test']['timestamp'],
                            mean=all_data['stats']['_mean'], std=all_data['stats']['_std']
                            )
    return all_data

sample and target were analyzed under the previous two functions. In this function, sample undergoes one extra operation, hour_sample = np.expand_dims(hour_sample, axis=0).transpose((0, 2, 3, 1)), and target undergoes target = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :]. These operations are explained below.

  • np.expand_dims(hour_sample, axis=0): inserts a new dimension at the given axis; hour_sample's original shape is (12, 307, 3), and after expand_dims it becomes (1, 12, 307, 3).

  • target's original shape is also (12, 307, 3); after expand_dims it becomes (1, 12, 307, 3).

  • .transpose((0, 2, 3, 1)): reorders hour_sample's dimensions in the given order; (1, 12, 307, 3) becomes (1, 307, 3, 12).

  • .transpose((0, 2, 3, 1))[:, :, 0, :]: first turns target into (1, 307, 3, 12), then takes the first feature along the third axis; indexing that axis with the scalar 0 drops it, so target becomes (1, 307, 12).
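The shape bookkeeping above can be verified directly with NumPy (random data standing in for the real hour_sample and target):

```python
import numpy as np

hour_sample = np.random.rand(12, 307, 3)          # (T, N, F)
target = np.random.rand(12, 307, 3)

x = np.expand_dims(hour_sample, axis=0)           # (1, 12, 307, 3)
x = x.transpose((0, 2, 3, 1))                     # (1, 307, 3, 12) = (1, N, F, T)
print(x.shape)

# Indexing axis 2 with the scalar 0 drops that axis entirely.
y = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :]
print(y.shape)                                    # (1, 307, 12)
```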

Cross-checking against the raw data:

  • Convert sample[i] (shape (1, 307, 3, 12) after the transpose above) to a csv file of shape (307, 12), keeping only the flow feature
# convert sample[i] to csv, keeping only the flow feature
import pandas as pd
def sampletocsv(i, sample):
    sample = sample[i]
    print("sample.shape:", sample.shape)
    # extract the flow feature only
    reshaped_sample = sample[:, :, :1, :]
    print("Reshaped shape:", reshaped_sample.shape)
    data_2d = reshaped_sample.reshape(-1, reshaped_sample.shape[-1])
    df = pd.DataFrame(data_2d)
    df.to_csv(f'npytocsv/sample_dim1_{i}.csv', index=False)
if __name__ == '__main__':
    sampletocsv(0, sample)

The output is shown below. Each sample_i file has shape (307, 12); per the example output of the get_sample_indices code above, there are 16969 such sample_i files.
[image]
sample_0 (partial):
[image]
sample_1 (partial):
[image]
sample_2 (partial):
[image]

  • Convert targetlist[i] (shape (1, 307, 12)) to a csv file
# convert targetlist[i] (1, 307, 12) to csv
# run the search_data and get_sample_indices test code, not the read_and_generate_dataset test
import pandas as pd
def targetlisttocsv(i, targetlist):
    targetlist = targetlist[i]
    print("targetlist[i].shape:", targetlist.shape)
    targetlist_2d = targetlist.reshape(targetlist.shape[1], targetlist.shape[2])
    df = pd.DataFrame(targetlist_2d)
    df.to_csv(f'npytocsv/targetlist_{i}.csv', index=False)
if __name__ == '__main__':
    targetlisttocsv(2, targetlist)

[image]
targetlist_0 (partial):
[image]
targetlist_1 (partial):
[image]
targetlist_2 (partial):
[image]

  • Extract part of the raw PEMS04.npz data and convert it to csv
# convert the first i rows of feature j+1 of the raw data to csv
import numpy as np
import pandas as pd
def dataseqtocsv(i, j):
    data_seq = np.load(graph_signal_matrix_filename)['data']
    print("data_seq.shape", data_seq.shape)

    subset_data_seq = data_seq[:i, :, j]
    print("subset_data_seq.shape:", subset_data_seq.shape)

    df = pd.DataFrame(subset_data_seq)
    df.to_csv(f'subset_data_seq{j+1}.csv', index=False)
if __name__ == '__main__':
    dataseqtocsv(50, 0)

For readability only the first 50 of the 16992 rows are written, keeping only the flow feature. Output:
[image]
Comparing sample and target together against the raw dataset:
[image]
[image]

  • Summary: get_sample_indices takes the indices produced by search_data and extracts the corresponding data and labels from the raw dataset, forming the recent (num_of_hours), daily (num_of_days) and weekly (num_of_weeks) data we need.

After hour_sample is processed, target and time_sample are also appended, and the per-index sample list is appended to all_samples.

At this point all_samples = [[hour_sample, target, time_sample], ..., [hour_sample, target, time_sample]],

where each hour_sample is (1, 307, 3, 12), each target is (1, 307, 12), and each time_sample is (1, 1); time_sample of the first sample is [[12]] and of the last (index 16968) is [[16980]].

Next all_samples is split by ratio: 60% into training_set, 20% into validation_set, 20% into testing_set.

training_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[:split_line1])] 
validation_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[split_line1: split_line2])]
testing_set = [np.concatenate(i, axis=0)
                    for i in zip(*all_samples[split_line2:])]

training_set中取出从第一个到倒数第二个之间元素(不包括倒数第二个),即[hour_sample]=(1,307,3,12),并沿着最后一个轴(时间)连接起来,组成train_xval_x test_x 同理。

    train_x = np.concatenate(training_set[:-2], axis=-1)
    val_x = np.concatenate(validation_set[:-2], axis=-1)
    test_x = np.concatenate(testing_set[:-2], axis=-1)

training_set中取出倒数第二个元素,即target=(1,307,12),组成train_targetval_target test_target同理。

	train_target = training_set[-2]  
    val_target = validation_set[-2]
    test_target = testing_set[-2]

training_set中取出最后一个元素,即time_sample =(1,1),组成train_timestamp val_timestamp test_timestamp 同理。

    train_timestamp = training_set[-1]  
    val_timestamp = validation_set[-1]
    test_timestamp = testing_set[-1]

Next comes the normalization step; first look at the normalization function.

The normalization function

  • Purpose: standardize train, val and test using the training set's mean and standard deviation, so that the training data has mean 0 and standard deviation 1 along the normalized axes.
  • Walkthrough:
def normalization(train, val, test):
    # ensure train, val and test agree in every dimension after the batch axis
    assert train.shape[1:] == val.shape[1:] and val.shape[1:] == test.shape[1:]  # ensure the num of nodes is the same
    mean = train.mean(axis=(0, 1, 3), keepdims=True)
    std = train.std(axis=(0, 1, 3), keepdims=True)
    print('mean.shape:', mean.shape)
    print('std.shape:', std.shape)
    def normalize(x):
        return (x - mean) / std
    train_norm = normalize(train)
    val_norm = normalize(val)
    test_norm = normalize(test)
    return {'_mean': mean, '_std': std}, train_norm, val_norm, test_norm

mean = train.mean(axis=(0, 1, 3), keepdims=True): computes the mean of the train set over axes 0, 1 and 3, keeping those dimensions so the result broadcasts later.
std = train.std(axis=(0, 1, 3), keepdims=True): computes the standard deviation of the train set over the same axes, also keeping the dimensions for broadcasting.
The function returns a stats dictionary plus the standardized arrays; the caller then assembles them into the all_data dictionary.
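A quick check of the axis choice (random data in place of the real train set, with shapes shrunk from (B, 307, 3, T)): normalizing over axes (0, 1, 3) leaves one mean/std per feature.

```python
import numpy as np

train = np.random.rand(8, 5, 3, 12)   # (B, N, F, T)
mean = train.mean(axis=(0, 1, 3), keepdims=True)
std = train.std(axis=(0, 1, 3), keepdims=True)
print(mean.shape)  # (1, 1, 3, 1): one statistic per feature, broadcast over B, N, T

train_norm = (train - mean) / std
# Per feature, the normalized training data has zero mean and unit std.
print(np.allclose(train_norm.mean(axis=(0, 1, 3)), 0),
      np.allclose(train_norm.std(axis=(0, 1, 3)), 1))
```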

Overall recap

for idx in range(16992):  # 0-16991
  search_data produces hour_indice, e.g. [(0, 12)] when idx = 12
  get_sample_indices uses hour_indice to take rows 0-11 of the raw data as hour_sample, shape (12, 307, 3); after expand_dims and transpose it becomes (1, 307, 3, 12)
  rows 12-23 of the raw data become target, shape (12, 307, 3); after expand_dims, transpose and taking the first feature it becomes (1, 307, 12)
  idx itself becomes time_sample: [[idx]]
  hour_sample, target and time_sample are appended to sample
  sample is appended to all_samples, so all_samples = [[hour_sample, target, time_sample], ..., [hour_sample, target, time_sample]]
