Temporal Fusion Transformersfor Interpretable Multi-horizon Time Series Forecasting代码解读(tensoreflow)

本文深入解读 Temporal Fusion Transformer (TFT) 的框架和代码实现,该模型用于多步时间序列预测,特别适用于异常检测。文章详细介绍了TFT的结构,包括Observed_Inputs、Known_Inputs、Static_Inputs的分解,以及关键组件如线性层、门限单元、残差网络和自注意力机制。通过TensorFlow代码实现,讨论了环境要求、文件夹结构、数据处理和模型搭建。TFT的优势在于其GRN和可解释的多头自注意力机制,提高模型性能和学习效率。
摘要由CSDN通过智能技术生成

论文来源:https://arxiv.org/pdf/1912.09363.pdf  

代码来源:google-research/tft at master · google-research/google-research · GitHub        

目录

1、框架介绍 

2、代码详解

2.1 tensorflow环境要求

2.2 文件夹框架

         2.3 数据结构

         2.4 tft_model.py框架搭建

        1、def linear_layer()

        2、apply_mlp() 

        3、def apply_gating_layer()

        4、def add_and_norm()

        5、def gated_residual_network()

        6、def get_decoder_mask()

        7、class ScaledDotProductAttention()

        8、 class InterpretableMultiHeadAttention()

        9、class TemporalFusionTransformer()

        10、def _batch_sampled_data()等

         11、def _build_base_graph()

        12、def build_model(self)

        13、def fit(self)

        14、def evaluate(self)

        15、……接下来的诸如predict函数都是一般神经网络的基本步骤,没什么特殊(我也写不动了)。

        2.5 script_train_fixed_params.py

          3、总结



1、框架介绍 

        框架的话,我直接粘了论文的原图。

        TFT用于时序预测,也有异常预测等具体应用。

        如上图,TFT将原始的时序数据分解为三部分:Observed_Inputs、Known_Inputs、Static_inputs。其中Observed_Inputs(已观测输入)即历史KPI数据,且已知这些数据的Target(输出);Known_Inputs指所有条目都已知的数据(包括历史的以及接下来需要预测的),例如时间戳;Static_Inputs指静态输入,本人理解为离散输入,对预测结果的影响不大的输入,比如CPU占用率数据中的计算机类别ID。

        简单地介绍一下用上述三个数据完成异常预测:首先将数据集分割为上述三部分,其次分别训练历史已观测输入、历史已知数据与历史静态输入 学习得到 目标输出(异常标签),测试过程中输入已知数据,通过已学习得的静态输入与已观测输入的特征矩阵,预测相应输出。

        看到这里,相信大家一定也有很多具体实现的疑惑,下面将通过代码介绍具体框架介绍。

2、代码详解

        打开script_download_data.py与script_train_fixed_params.py,将其中add_argument函数中expt_name的default设置为electricity(或其他),运行script_download_data.py下载数据集,运行script_train_fixed_params.py实现TFT。

        2.1 tensorflow环境要求

        因为这里的代码是由tensorflow1版本完成的,而现在大部分使用的都是tensorflow2。因此,需要对调用的tensorflow代码进行相应更改。

(1)更改  import tensorflow as tf

import tensorflow.compat.v1 as tf

(2)model文件头添加:

tf.compat.v1.experimental.output_all_intermediates(True)

        说实话,这是代码报错后要求的,我也不知道什么原理,有机会再调研一下。

        2.2 文件夹框架

        data_formatters文件夹主要完成文件的下载与预处理

        expt_settings文件夹完成各种参数的配置,前期不需要太关注 

        libs文件夹中tft_model.py文件实现神经网络框架的搭建

        script_download_data.py 下载原始数据集,script_train_fixed_params.py 运行默认参数的TFT,script_hyperparam_opt.py是具体调参的TFT实现。

         script_download_data.py较简单,不做详解,具体介绍script_train_fixed_params.py中TFT的实现流程。

        2.3 数据结构

        传统数据集由时间戳、KPI具体值与输出值等组成。data_formatters文件夹的favorita.py等实现原始数据集的预处理。

        去除时间戳与序号后的数据列可分为以下几类:

        列为输入类型,行为数据类型

observed_input  known_input      target              static     
real_value
category 

         2.4 tft_model.py框架搭建

        这里,我们一个函数一个函数讲。

        1、def linear_layer()

        定义Dense线性层。但相比Dense,增加了一个TimeDistributed层,在每个时间步上均操作Dense。

        2、apply_mlp() 

        定义两层Dense,MLP多层感知器。

        3、def apply_gating_layer()

        定义GLU门限单元,这个在论文中有提到:

        具体操作即Dropout后,分别定义激活函数为sigmoid与无激活函数的Dense层,将两Dense层的输出矩阵相乘即获得门限单元。

        门限单元的作用即门限,相当

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值