解读base_model.py
from models.layers import *
from os.path import join as pjoin
import tensorflow._api.v2.compat.v1 as tf
tf.disable_v2_behavior()
build_model()函数
def build_model(inputs, n_his, Ks, Kt, blocks, keep_prob):
'''
Build the base model.
:param inputs: placeholder.
:param n_his: int, size of historical records for training.
:param Ks: int, kernel size of spatial convolution.
:param Kt: int, kernel size of temporal convolution.
:param blocks: list, channel configs of st_conv blocks.
:param keep_prob: placeholder.
'''
x = inputs[:, 0:n_his, :, :]
# Ko>0: kernel size of temporal convolution in the output layer.
Ko = n_his
# ST-Block
for i, channels in enumerate(blocks):
x = st_conv_block(x, Ks, Kt, channels, i, keep_prob, act_func='GLU')
Ko -= 2 * (Kt - 1)
# Output Layer
if Ko > 1:
y = output_layer(x, Ko, 'output_layer')
else:
raise ValueError(f'ERROR: kernel size Ko must be greater than 1, but received "{Ko}".')
tf.add_to_collection(name='copy_loss',
value=tf.nn.l2_loss(inputs[:, n_his - 1:n_his, :, :] - inputs[:, n_his:n_his + 1, :, :]))
train_loss = tf.nn.l2_loss(y - inputs[:, n_his:n_his + 1, :, :])
single_pred = y[:, 0, :, :]
tf.add_to_collection(name='y_pred', value=single_pred)
return train_loss, single_pred
这段代码是用于构建一个模型的TensorFlow函数,特别是一个基于时空卷积(ST-Conv)的神经网络模型。这种模型通常用于处理时间序列数据,其中包括空间和时间两个维度(例如交通流量预测)。
函数参数:
inputs
: 输入数据的占位符(placeholder),通常是一个四维张量。n_his
: 历史记录的大小,用于训练模型。Ks
: 空间卷积的核大小。Kt
: 时间卷积的核大小。blocks
: 列表,包含ST-Conv块的通道配置。keep_prob
: 占位符,用于设置Dropout层的保持概率。
主要变量和过程:
-
提取历史数据:
x = inputs[:, 0:n_his, :, :]
从输入数据中提取用于训练的历史记录。 -
初始化输出层的时间卷积核大小(Ko): 初始值设置为
n_his
。 -
构建ST-Conv块:
- 遍历
blocks
列表,在每次迭代中调用st_conv_block
函数来添加一个ST-Conv块。 - 更新
Ko
以考虑新添加的卷积层。
- 遍历
-
输出层:
- 如果
Ko
大于1,则添加输出层,否则抛出错误。
- 如果
-
损失函数和预测值:
- 添加一个L2损失(平方损失)到TensorFlow的’copy_loss’集合。
- 计算训练损失(
train_loss
)为预测值(y
)和真实值之间的L2损失。 - 提取单步预测值(
single_pred
)并将其添加到TensorFlow的’y_pred’集合。
返回值:
train_loss
: 训练损失,用于优化模型。single_pred
: 单步预测值,用于后续评估。
需要注意的是,该函数引用了一些外部函数(如st_conv_block
和output_layer
)和TensorFlow特定的操作(如tf.add_to_collection
和tf.nn.l2_loss
)。
model_save()函数
def model_save(sess, global_steps, model_name, save_path='./output/models/'):
'''
Save the checkpoint of trained model.
:param sess: tf.Session().
:param global_steps: tensor, record the global step of training in epochs.
:param model_name: str, the name of saved model.
:param save_path: str, the path of saved model.
:return:
'''
saver = tf.train.Saver(max_to_keep=3)
prefix_path = saver.save(sess, pjoin(save_path, model_name), global_step=global_steps)
print(f'<< Saving model to {prefix_path} ...')
这个函数是用于保存TensorFlow模型的训练检查点(checkpoint)。这是在训练深度学习模型时非常常见的操作,因为这样可以在训练过程中或训练完成后保存模型的状态,以便于之后进行评估或继续训练。
函数参数:
sess
: 一个TensorFlow会话对象(tf.Session()
),用于运行TensorFlow计算图。global_steps
: 一个张量,记录了训练过程中的全局步骤(epochs)。model_name
: 保存的模型的名称,这通常是一个字符串。save_path
: 保存模型的路径,如果不指定,默认是'./output/models/'
。
主要操作:
-
创建 Saver 对象:使用
tf.train.Saver(max_to_keep=3)
创建一个Saver对象,max_to_keep=3
意味着在文件系统中最多保留3个检查点文件。 -
保存模型:
- 调用Saver对象的
save
方法,将当前会话(sess
)的状态保存到指定路径(save_path
和model_name
的组合)。 global_step=global_steps
将全局步骤数添加到保存的文件名中,以区分不同的检查点。
- 调用Saver对象的
-
打印保存信息:输出一条信息,告知用户模型已经被保存,并显示保存的路径。
注意点:
pjoin
可能是os.path.join
的别名,用于连接目录和文件名。这部分在代码中没有给出,但这是一个合理的猜测。
这个函数没有返回值,它的主要作用是副作用(即,保存模型到磁盘)。这样,在训练过程中或训练结束后,你可以从磁盘加载这些检查点来恢复模型状态。