论文来源:https://arxiv.org/pdf/1912.09363.pdf
代码来源:google-research/tft at master · google-research/google-research · GitHub
目录
5、def gated_residual_network()
7、class ScaledDotProductAttention()
8、 class InterpretableMultiHeadAttention()
9、class TemporalFusionTransformer()
15、……接下来的诸如predict函数都是一般神经网络的基本步骤,没什么特殊(我也写不动了)。
2.5 script_train_fixed_params.py
1、框架介绍
框架的话,我直接粘了论文的原图。
TFT用于时序预测,也有异常预测等具体应用。
如上图,TFT将原始的时序数据分解为三部分:Observed_Inputs、Known_Inputs、Static_inputs。其中Observed_Inputs(已观测输入)即历史KPI数据,且已知这些数据的Target(输出);Known_Inputs指所有条目都已知的数据(包括历史的以及接下来需要预测的),例如时间戳;Static_Inputs指静态输入,本人理解为离散输入,对预测结果的影响不大的输入,比如CPU占用率数据中的计算机类别ID。
简单地介绍一下用上述三个数据完成异常预测:首先将数据集分割为上述三部分,其次分别训练历史已观测输入、历史已知数据与历史静态输入 学习得到 目标输出(异常标签),测试过程中输入已知数据,通过已学习得的静态输入与已观测输入的特征矩阵,预测相应输出。
看到这里,相信大家一定也有很多具体实现的疑惑,下面将通过代码介绍具体框架介绍。
2、代码详解
打开script_download_data.py与script_train_fixed_params.py,将其中add_argument函数中expt_name的default设置为electricity(或其他),运行script_download_data.py下载数据集,运行script_train_fixed_params.py实现TFT。
2.1 tensorflow环境要求
因为这里的代码是由tensorflow1版本完成的,而现在大部分使用的都是tensorflow2。因此,需要对调用的tensorflow代码进行相应更改。
(1)更改 import tensorflow as tf
import tensorflow.compat.v1 as tf
(2)model文件头添加:
tf.compat.v1.experimental.output_all_intermediates(True)
说实话,这是代码报错后要求的,我也不知道什么原理,有机会再调研一下。
2.2 文件夹框架
data_formatters文件夹主要完成文件的下载与预处理
expt_settings文件夹完成各种参数的配置,前期不需要太关注
libs文件夹中tft_model.py文件实现神经网络框架的搭建
script_download_data.py 下载原始数据集,script_train_fixed_params.py 运行默认参数的TFT,script_hyperparam_opt.py是具体调参的TFT实现。
script_download_data.py较简单,不做详解,具体介绍script_train_fixed_params.py中TFT的实现流程。
2.3 数据结构
传统数据集由时间戳、KPI具体值与输出值等组成。data_formatters文件夹的favorita.py等实现原始数据集的预处理。
去除时间戳与序号后的数据列可分为以下几类:
列为输入类型,行为数据类型
observed_input | known_input | target | static | |
real_value | ||||
category |
2.4 tft_model.py框架搭建
这里,我们一个函数一个函数讲。
1、def linear_layer()
定义Dense线性层。但相比Dense,增加了一个TimeDistributed层,在每个时间步上均操作Dense。
2、apply_mlp()
定义两层Dense,MLP多层感知器。
3、def apply_gating_layer()
定义GLU门限单元,这个在论文中有提到:
具体操作即Dropout后,分别定义激活函数为sigmoid与无激活函数的Dense层,将两Dense层的输出矩阵相乘即获得门限单元。
门限单元的作用即门限,相当