STGCN_IJCAI-18-master代码解读(四):base_model.py

解读base_model.py

from models.layers import *
from os.path import join as pjoin
import tensorflow._api.v2.compat.v1 as tf
tf.disable_v2_behavior()

build_model()函数

def build_model(inputs, n_his, Ks, Kt, blocks, keep_prob):
    '''
    Build the base model.
    :param inputs: placeholder.
    :param n_his: int, size of historical records for training.
    :param Ks: int, kernel size of spatial convolution.
    :param Kt: int, kernel size of temporal convolution.
    :param blocks: list, channel configs of st_conv blocks.
    :param keep_prob: placeholder.
    '''
    x = inputs[:, 0:n_his, :, :]

    # Ko>0: kernel size of temporal convolution in the output layer.
    Ko = n_his
    # ST-Block
    for i, channels in enumerate(blocks):
        x = st_conv_block(x, Ks, Kt, channels, i, keep_prob, act_func='GLU')
        Ko -= 2 * (Kt - 1)

    # Output Layer
    if Ko > 1:
        y = output_layer(x, Ko, 'output_layer')
    else:
        raise ValueError(f'ERROR: kernel size Ko must be greater than 1, but received "{Ko}".')

    tf.add_to_collection(name='copy_loss',
                         value=tf.nn.l2_loss(inputs[:, n_his - 1:n_his, :, :] - inputs[:, n_his:n_his + 1, :, :]))
    train_loss = tf.nn.l2_loss(y - inputs[:, n_his:n_his + 1, :, :])
    single_pred = y[:, 0, :, :]
    tf.add_to_collection(name='y_pred', value=single_pred)
    return train_loss, single_pred

这段代码是用于构建一个模型的TensorFlow函数,特别是一个基于时空卷积(ST-Conv)的神经网络模型。这种模型通常用于处理时间序列数据,其中包括空间和时间两个维度(例如交通流量预测)。

函数参数:

  • inputs: 输入数据的占位符(placeholder),通常是一个四维张量。
  • n_his: 历史记录的大小,用于训练模型。
  • Ks: 空间卷积的核大小。
  • Kt: 时间卷积的核大小。
  • blocks: 列表,包含ST-Conv块的通道配置。
  • keep_prob: 占位符,用于设置Dropout层的保持概率。

主要变量和过程:

  1. 提取历史数据x = inputs[:, 0:n_his, :, :]从输入数据中提取用于训练的历史记录。

  2. 初始化输出层的时间卷积核大小(Ko): 初始值设置为n_his

  3. 构建ST-Conv块

    • 遍历blocks列表,在每次迭代中调用st_conv_block函数来添加一个ST-Conv块。
    • 更新Ko以考虑新添加的卷积层。
  4. 输出层

    • 如果Ko大于1,则添加输出层,否则抛出错误。
  5. 损失函数和预测值

    • 添加一个L2损失(平方损失)到TensorFlow的’copy_loss’集合。
    • 计算训练损失(train_loss)为预测值(y)和真实值之间的L2损失。
    • 提取单步预测值(single_pred)并将其添加到TensorFlow的’y_pred’集合。

返回值:

  • train_loss: 训练损失,用于优化模型。
  • single_pred: 单步预测值,用于后续评估。

需要注意的是,该函数引用了一些外部函数(如st_conv_blockoutput_layer)和TensorFlow特定的操作(如tf.add_to_collectiontf.nn.l2_loss)。

model_save()函数

def model_save(sess, global_steps, model_name, save_path='./output/models/'):
    '''
    Save the checkpoint of trained model.
    :param sess: tf.Session().
    :param global_steps: tensor, record the global step of training in epochs.
    :param model_name: str, the name of saved model.
    :param save_path: str, the path of saved model.
    :return:
    '''
    saver = tf.train.Saver(max_to_keep=3)
    prefix_path = saver.save(sess, pjoin(save_path, model_name), global_step=global_steps)
    print(f'<< Saving model to {prefix_path} ...')

这个函数是用于保存TensorFlow模型的训练检查点(checkpoint)。这是在训练深度学习模型时非常常见的操作,因为这样可以在训练过程中或训练完成后保存模型的状态,以便于之后进行评估或继续训练。

函数参数:

  • sess: 一个TensorFlow会话对象(tf.Session()),用于运行TensorFlow计算图。
  • global_steps: 一个张量,记录了训练过程中的全局步骤(epochs)。
  • model_name: 保存的模型的名称,这通常是一个字符串。
  • save_path: 保存模型的路径,如果不指定,默认是'./output/models/'

主要操作:

  1. 创建 Saver 对象:使用tf.train.Saver(max_to_keep=3)创建一个Saver对象,max_to_keep=3意味着在文件系统中最多保留3个检查点文件。

  2. 保存模型

    • 调用Saver对象的save方法,将当前会话(sess)的状态保存到指定路径(save_pathmodel_name的组合)。
    • global_step=global_steps将全局步骤数添加到保存的文件名中,以区分不同的检查点。
  3. 打印保存信息:输出一条信息,告知用户模型已经被保存,并显示保存的路径。

注意点:

  • pjoin可能是os.path.join的别名,用于连接目录和文件名。这部分在代码中没有给出,但这是一个合理的猜测。

这个函数没有返回值,它的主要作用是副作用(即,保存模型到磁盘)。这样,在训练过程中或训练结束后,你可以从磁盘加载这些检查点来恢复模型状态。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值