STGCN_IJCAI-18-master代码解读(六):tester.py

解读tester.py

from data_loader.data_utils import gen_batch
from utils.math_utils import evaluation
from os.path import join as pjoin

import tensorflow._api.v2.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import time

multi_pred()函数

def multi_pred(sess, y_pred, seq, batch_size, n_his, n_pred, step_idx, dynamic_batch=True):
    '''
    Multi_prediction function.
    :param sess: tf.Session().
    :param y_pred: placeholder.
    :param seq: np.ndarray, [len_seq, n_frame, n_route, C_0].
    :param batch_size: int, the size of batch.
    :param n_his: int, size of historical records for training.
    :param n_pred: int, the length of prediction.
    :param step_idx: int or list, index for prediction slice.
    :param dynamic_batch: bool, whether changes the batch size in the last one if its length is less than the default.
    :return y_ : tensor, 'sep' [len_inputs, n_route, 1]; 'merge' [step_idx, len_inputs, n_route, 1].
            len_ : int, the length of prediction.
    '''
    pred_list = []
    for i in gen_batch(seq, min(batch_size, len(seq)), dynamic_batch=dynamic_batch):
        # Note: use np.copy() to avoid the modification of source data.
        test_seq = np.copy(i[:, 0:n_his + 1, :, :])
        step_list = []
        for j in range(n_pred):
            pred = sess.run(y_pred,
                            feed_dict={'data_input:0': test_seq, 'keep_prob:0': 1.0})
            if isinstance(pred, list):
                pred = np.array(pred[0])
            test_seq[:, 0:n_his - 1, :, :] = test_seq[:, 1:n_his, :, :]
            test_seq[:, n_his - 1, :, :] = pred
            step_list.append(pred)
        pred_list.append(step_list)
    #  pred_array -> [n_pred, batch_size, n_route, C_0)
    pred_array = np.concatenate(pred_list, axis=1)
    return pred_array[step_idx], pred_array.shape[1]

这个函数主要用于在一个时间序列数据上执行多步预测。它使用给定的模型(y_pred placeholder)和输入序列(seq)进行预测,并返回预测结果。

参数:

  • sess: TensorFlow 的会话对象,用于执行模型预测。
  • y_pred: TensorFlow 的 placeholder,用于保存模型的预测结果。
  • seq: numpy 数组,包含输入序列,其形状为 [len_seq, n_frame, n_route, C_0]。
  • batch_size: 批处理大小。
  • n_his: 历史记录的大小,用于训练。
  • n_pred: 预测长度。
  • step_idx: 用于预测切片的索引,可以是整数或列表。
  • dynamic_batch: 布尔值,决定是否在最后一个批次长度小于默认批处理大小时改变批处理大小。

实现细节:

  1. 预测列表初始化: pred_list 用于存储每一批次的预测结果。

  2. 批处理生成与预测: 对于 seq 中的每一批数据,执行以下操作:

    • seq 中获取测试序列 test_seq
    • 初始化一个列表 step_list,用于存储单个批次的多步预测结果。
    • n_pred 步内,进行预测,并将每一步的预测结果存储在 step_list 中。
  3. 更新测试序列: 在每一步预测后,test_seq 会更新,以包含最新的预测值,以供下一步预测使用。

  4. 预测结果聚合: 所有批次的预测结果都存储在 pred_list 中,然后通过 np.concatenate 转换为 numpy 数组 pred_array

  5. 返回结果: 最终返回预测结果 pred_array 的特定切片(由 step_idx 定义)和预测的长度。

这个函数是多步时间序列预测任务的核心,它能批量地进行多步预测,并且可以处理动态批处理大小的情况。这对于在实际应用中预测交通流量、股票价格等具有时空依赖性的数据非常有用。

model_inference()函数

def model_inference(sess, pred, inputs, batch_size, n_his, n_pred, step_idx, min_va_val, min_val):
    '''
    Model inference function.
    :param sess: tf.Session().
    :param pred: placeholder.
    :param inputs: instance of class Dataset, data source for inference.
    :param batch_size: int, the size of batch.
    :param n_his: int, the length of historical records for training.
    :param n_pred: int, the length of prediction.
    :param step_idx: int or list, index for prediction slice.
    :param min_va_val: np.ndarray, metric values on validation set.
    :param min_val: np.ndarray, metric values on test set.
    '''
    x_val, x_test, x_stats = inputs.get_data('val'), inputs.get_data('test'), inputs.get_stats()

    if n_his + n_pred > x_val.shape[1]:
        raise ValueError(f'ERROR: the value of n_pred "{n_pred}" exceeds the length limit.')

    y_val, len_val = multi_pred(sess, pred, x_val, batch_size, n_his, n_pred, step_idx)
    evl_val = evaluation(x_val[0:len_val, step_idx + n_his, :, :], y_val, x_stats)

    # chks: indicator that reflects the relationship of values between evl_val and min_va_val.
    chks = evl_val < min_va_val
    # update the metric on test set, if model's performance got improved on the validation.
    if sum(chks):
        min_va_val[chks] = evl_val[chks]
        y_pred, len_pred = multi_pred(sess, pred, x_test, batch_size, n_his, n_pred, step_idx)
        evl_pred = evaluation(x_test[0:len_pred, step_idx + n_his, :, :], y_pred, x_stats)
        min_val = evl_pred
    return min_va_val, min_val

这个函数用于模型推断,并评估模型在验证集和测试集上的性能。它使用一个预先训练好的模型(由 TensorFlow 会话 sess 和 placeholder pred 表示)以及输入数据(由 Dataset 类的 inputs 实例提供)来进行操作。

参数:

  • sess: TensorFlow 的会话对象,用于执行模型预测。
  • pred: TensorFlow 的 placeholder,用于保存模型的预测结果。
  • inputs: Dataset 类的实例,提供数据源。
  • batch_size: 批处理大小。
  • n_his: 历史记录的大小,用于训练。
  • n_pred: 预测长度。
  • step_idx: 用于预测切片的索引,可以是整数或列表。
  • min_va_val: numpy 数组,保存验证集上的最小评估值。
  • min_val: numpy 数组,保存测试集上的最小评估值。

实现细节:

  1. 数据准备: 从 inputs 实例中获取验证集(x_val)和测试集(x_test)数据。

  2. 参数检查: 验证 n_hisn_pred 的和是否超过 x_val 的长度限制。

  3. 在验证集上进行多步预测: 使用 multi_pred 函数在验证集上进行多步预测,结果保存在 y_val

  4. 评估验证集上的性能: 使用 evaluation 函数计算 y_val 和实际值之间的误差,结果保存在 evl_val

  5. 性能检查与更新:

    • 检查 evl_val 是否小于 min_va_val(先前在验证集上获得的最佳性能指标)。
    • 如果有改进,则更新 min_va_val 和执行测试集上的预测以获取新的 min_val

这个函数的主要目的是通过不断地在验证集上进行评估来微调模型,并据此更新测试集上的性能指标。这是一种常见的机器学习实践,用于避免模型过拟合并在新数据上获得更好的性能。

model_test()函数

def model_test(inputs, batch_size, n_his, n_pred, inf_mode, load_path='./output/models/'):
    '''
    Load and test saved model from the checkpoint.
    :param inputs: instance of class Dataset, data source for test.
    :param batch_size: int, the size of batch.
    :param n_his: int, the length of historical records for training.
    :param n_pred: int, the length of prediction.
    :param inf_mode: str, test mode - 'merge / multi-step test' or 'separate / single-step test'.
    :param load_path: str, the path of loaded model.
    '''
    start_time = time.time()
    model_path = tf.train.get_checkpoint_state(load_path).model_checkpoint_path

    test_graph = tf.Graph()

    with test_graph.as_default():
        saver = tf.train.import_meta_graph(pjoin(f'{model_path}.meta'))

    with tf.Session(graph=test_graph) as test_sess:
        saver.restore(test_sess, tf.train.latest_checkpoint(load_path))
        print(f'>> Loading saved model from {model_path} ...')

        pred = test_graph.get_collection('y_pred')

        if inf_mode == 'sep':
            # for inference mode 'sep', the type of step index is int.
            step_idx = n_pred - 1
            tmp_idx = [step_idx]
        elif inf_mode == 'merge':
            # for inference mode 'merge', the type of step index is np.ndarray.
            step_idx = tmp_idx = np.arange(3, n_pred + 1, 3) - 1
        else:
            raise ValueError(f'ERROR: test mode "{inf_mode}" is not defined.')

        x_test, x_stats = inputs.get_data('test'), inputs.get_stats()

        y_test, len_test = multi_pred(test_sess, pred, x_test, batch_size, n_his, n_pred, step_idx)
        evl = evaluation(x_test[0:len_test, step_idx + n_his, :, :], y_test, x_stats)

        for ix in tmp_idx:
            te = evl[ix - 2:ix + 1]
            print(f'Time Step {ix + 1}: MAPE {te[0]:7.3%}; MAE  {te[1]:4.3f}; RMSE {te[2]:6.3f}.')
        print(f'Model Test Time {time.time() - start_time:.3f}s')
    print('Testing model finished!')

这个函数的目的是从保存的检查点(checkpoint)加载模型,并在测试集上进行预测和评估。这对于在训练完成后验证模型的性能很有用。

参数

  • inputs: Dataset 类的实例,它是测试数据的数据源。
  • batch_size: 批处理的大小。
  • n_his: 用于训练的历史记录长度。
  • n_pred: 预测的长度。
  • inf_mode: 推断模式,可以是 ‘sep’(单步测试)或 ‘merge’(多步测试)。
  • load_path: 加载模型的路径。

实现细节

  1. 时间记录: 用于测量测试所需的总时间。
  2. 模型加载: 使用 TensorFlow 的 get_checkpoint_stateimport_meta_graph 来加载预训练模型。
  3. 会话恢复: 在一个新的 TensorFlow 会话中恢复模型。
  4. 预测 Placeholder: 使用 get_collection 获取预测结果的 placeholder。
  5. 推断模式选择: 根据 inf_mode 参数决定使用哪种推断模式(‘sep’ 或 ‘merge’)。
  6. 获取测试数据: 从 inputs 对象中获取测试数据和统计信息。
  7. 多步预测: 调用 multi_pred 函数进行多步预测。
  8. 评估性能: 调用 evaluation 函数来评估模型在测试集上的性能。
  9. 输出结果: 打印每个时间步的 MAPE(平均绝对百分比误差)、MAE(平均绝对误差)和 RMSE(均方根误差)。
  10. 时间消耗: 计算并打印模型测试所需的时间。

该函数提供了一个完整的流程,从加载预训练模型,到在测试集上进行预测和评估,再到最终打印出性能指标和所需时间。这样可以很方便地了解模型在未见过的数据上的性能。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值