Time Series | Temporal Fusion Transformer

Today I'd like to share a time-series paper from Google:

Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting.

First, let's consider what a good time-series model might need to handle:

  • Handles both univariate and multivariate time series
  • Uses not only the series' own history but also information from other variables (i.e. not just autoregression)
  • Exploits static information in addition to dynamic information, e.g. some statistical features
  • Adapts to time series with both complex and simple patterns
  • Multi-horizon forecasting is clearly better than recursively rolling single-step forecasts (it avoids error accumulation)
  • Not just point forecasts, but also uncertainty intervals for the predictions
  • Explainability is all we need!

Model architecture at a glance

Take a breath and don't be scared off by the architecture figure; it really isn't hard.

Let's take a bottom-up view and get a rough idea of what this model does.

  • Level 1: Input layer

Three parts: static information, historical information (the variable to be predicted), and future information (other known variables)

  • Level 2: Variable selection layer

Put simply, this performs feature selection

  • Level 3: LSTM encoder layer

Since this is a time series, using an LSTM to capture short- and long-term dependencies makes perfect sense

  • Level 4: Gate + Add&Norm

The gating can be seen as further weighing the importance of different features; the residual connection and normalization are standard practice

  • Level 5: GRN

Essentially the same idea as Level 4; think of it as deepening the network

  • Level 6: Attention

Weights the information from different time steps

  • Level 7: Output layer

Performs quantile regression, which gives us prediction intervals

Now that we have a rough picture from input to output, let's dissect some of the key modules of the network!

Since variable selection actually uses the GRN module introduced later, we start with the GRN.

Gated Residual Network (GRN)

The only pieces here that might be unfamiliar to some readers are ELU (Exponential Linear Unit) and GLU (Gated Linear Unit).

  • ELU is an activation that can take negative values (ELU(x) = x for x > 0 and α(eˣ − 1) for x ≤ 0, with α = 1 by default in PyTorch). Compared with ReLU, this lets the mean unit activation stay closer to zero, giving an effect similar to Batch Normalization at much lower computational cost. It also saturates softly for small inputs, which improves robustness to noise.
  • What GLU does is essentially an elementwise gating (weighting) of the input:

$h(X) = (XW + b) \otimes \sigma(XV + c)$

X is the input; W, b, V, and c are learned parameters. You can read it as: apply an affine transform to the input, then gate it elementwise, where the gate is another affine transform of the input squashed to (0, 1) by a sigmoid.

from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedLinearUnit(nn.Module):
    """Gated Linear Unit"""

    def __init__(self, input_size: int, hidden_size: int = None):
        super().__init__()

        self.hidden_size = hidden_size or input_size
        self.fc = nn.Linear(input_size, self.hidden_size * 2)

    def forward(self, x):
        x = self.fc(x)
        x = F.glu(x, dim=-1)
        return x
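As a quick sanity check (a minimal sketch; the sizes are arbitrary), F.glu(fc(x), dim=-1) is exactly the $(XW + b) \otimes \sigma(XV + c)$ form written out by hand:

glu = GatedLinearUnit(input_size=8)
x = torch.randn(4, 8)

# F.glu splits the last dimension in half: first_half * sigmoid(second_half)
out_glu = glu(x)

# the same computation written out explicitly
W, V = glu.fc.weight[:8], glu.fc.weight[8:]
b, c = glu.fc.bias[:8], glu.fc.bias[8:]
out_manual = (x @ W.T + b) * torch.sigmoid(x @ V.T + c)

assert torch.allclose(out_glu, out_manual, atol=1e-5)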
  • Residual connection and normalization
class AddNorm(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()

        self.input_size = input_size
        self.norm = nn.LayerNorm(self.input_size)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        output = self.norm(x + skip)
        return output

class GateAddNorm(nn.Module):
    def __init__(self,input_size: int, hidden_size: int = None):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size or input_size
        self.glu = GatedLinearUnit(self.input_size, hidden_size=self.hidden_size)
        self.add_norm = AddNorm(self.hidden_size)

    def forward(self, x, skip):
        output = self.glu(x)
        output = self.add_norm(output, skip)
        return output

Putting the whole module together:

class GatedResidualNetwork(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        context_size: int = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.context_size = context_size
        self.hidden_size = hidden_size

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.elu = nn.ELU()

        if self.context_size is not None:
            self.context = nn.Linear(self.context_size, self.hidden_size, bias=False)

        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)

        self.gate_norm = GateAddNorm(
            input_size=self.hidden_size,
            hidden_size=self.output_size
        )


    def forward(self, x, context=None, residual=None):
        # default the residual to the input itself; this simplified version assumes
        # input_size == output_size (pytorch_forecasting resamples the residual when they differ)
        if residual is None:
            residual = x
        x = self.fc1(x)
        if context is not None:
            context = self.context(context)
            x = x + context
        x = self.elu(x)
        x = self.fc2(x)
        x = self.gate_norm(x, residual)
        return x
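A minimal usage sketch (the sizes are arbitrary, chosen so that input_size equals output_size as the simplified forward above assumes):

grn = GatedResidualNetwork(input_size=16, hidden_size=16, output_size=16, context_size=16)
x = torch.randn(32, 10, 16)   # (batch, time, features)
c = torch.randn(32, 10, 16)   # static context, already expanded over the time dimension
out = grn(x, context=c)       # -> (32, 10, 16)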

Now that we've covered the GRN building block, let's see how variable selection is done.

Variable Selection Network (VSN)

The "variable selection" here is a soft selection: rather than dropping unimportant variables, it weights them, with larger weights indicating greater importance. (You can think of it as essentially an attention mechanism, just with a different way of computing the attention coefficients.)

Here we only consider the multi-variable case; the single-variable special case is not treated separately.
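In the paper's notation, for each time step $t$ with $m_\chi$ input variables, the flattened embedding $\Xi_t$ (optionally conditioned on a static context $c_s$) is passed through a GRN and a softmax to produce the variable weights, each variable gets its own GRN, and the output is the weighted sum:

$$v_{\chi_t} = \mathrm{Softmax}\big(\mathrm{GRN}_{v_\chi}(\Xi_t, c_s)\big), \qquad \tilde{\xi}_t = \sum_{j=1}^{m_\chi} v_{\chi_t}^{(j)} \, \mathrm{GRN}_{\tilde{\xi}(j)}\big(\xi_t^{(j)}\big)$$

The class below implements exactly this weighting.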

class VariableSelectionNetwork(nn.Module):
    def __init__(
        self,
        input_sizes: Dict[str, int],
        hidden_size: int,
        input_embedding_flags: Dict[str, bool] = {},
        context_size: int = None,
        single_variable_grns: Dict[str, GatedResidualNetwork] = {},
        prescalers: Dict[str, nn.Linear] = {},
    ):
        """
        Calculate weights for ``num_inputs`` variables, each of size ``input_size``
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        self.input_embedding_flags = input_embedding_flags
        self.context_size = context_size
        # number of variables and total flattened embedding width,
        # used by the GRN that computes the selection weights
        self.num_inputs = len(self.input_sizes)
        self.input_size_total = sum(self.input_sizes.values())


        if self.context_size is not None:
            self.flattened_grn = GatedResidualNetwork(
                self.input_size_total,
                min(self.hidden_size, self.num_inputs),
                self.num_inputs,
                self.context_size,
            )
        else:
            self.flattened_grn = GatedResidualNetwork(
                self.input_size_total,
                min(self.hidden_size, self.num_inputs),
                self.num_inputs,
            )

        self.single_variable_grns = nn.ModuleDict()
        self.prescalers = nn.ModuleDict()
        for name, input_size in self.input_sizes.items():
            if name in single_variable_grns:
                self.single_variable_grns[name] = single_variable_grns[name]
            elif self.input_embedding_flags.get(name, False):
                self.single_variable_grns[name] = ResampleNorm(input_size, self.hidden_size)
            else:
                self.single_variable_grns[name] = GatedResidualNetwork(
                    input_size,
                    min(input_size, self.hidden_size),
                    output_size=self.hidden_size,
                )
            if name in prescalers:  # reals need to be first scaled up
                self.prescalers[name] = prescalers[name]
            elif not self.input_embedding_flags.get(name, False):
                self.prescalers[name] = nn.Linear(1, input_size)

        self.softmax = nn.Softmax(dim=-1)


    def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
        # transform single variables
        var_outputs = []
        weight_inputs = []
        for name in self.input_sizes.keys():
            # select embedding belonging to a single input
            variable_embedding = x[name]
            if name in self.prescalers:
                variable_embedding = self.prescalers[name](variable_embedding)
            weight_inputs.append(variable_embedding)
            var_outputs.append(self.single_variable_grns[name](variable_embedding))
        var_outputs = torch.stack(var_outputs, dim=-1)

        # compute the variable-selection weights
        flat_embedding = torch.cat(weight_inputs, dim=-1)
        sparse_weights = self.flattened_grn(flat_embedding, context)
        sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)

        outputs = var_outputs * sparse_weights  # weighted sum
        outputs = outputs.sum(dim=-1)

        return outputs, sparse_weights
  • LSTM encoder

Here we can essentially use PyTorch's built-in LSTM (pytorch_forecasting wraps it in a thin class that handles variable sequence lengths).

self.lstm_encoder = LSTM(
            input_size=self.hparams.hidden_size,
            hidden_size=self.hparams.hidden_size,
            num_layers=self.hparams.lstm_layers,
            dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
            batch_first=True,
        )

Next comes the design of the attention mechanism.

  • Attention

The paper's modification of multi-head attention is to share part of the parameters: each head gets its own linear projections for Q and K, but the projection for V is shared across all heads.

class InterpretableMultiHeadAttention(nn.Module):
    def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
        super(InterpretableMultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = self.d_q = self.d_v = d_model // n_head
        self.dropout = nn.Dropout(p=dropout)

        self.v_layer = nn.Linear(self.d_model, self.d_v)
        self.q_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_q) for _ in range(self.n_head)])
        self.k_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_k) for _ in range(self.n_head)])
        self.attention = ScaledDotProductAttention()
        self.w_h = nn.Linear(self.d_v, self.d_model, bias=False)

    def forward(self, q, k, v, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        heads = []
        attns = []
        vs = self.v_layer(v)  # shared value projection across heads
        for i in range(self.n_head):
            qs = self.q_layers[i](q)
            ks = self.k_layers[i](k)
            head, attn = self.attention(qs, ks, vs, mask)
            head_dropout = self.dropout(head)
            heads.append(head_dropout)
            attns.append(attn)

        head = torch.stack(heads, dim=2) if self.n_head > 1 else heads[0]
        attn = torch.stack(attns, dim=2)

        outputs = torch.mean(head, dim=2) if self.n_head > 1 else head
        outputs = self.w_h(outputs)
        outputs = self.dropout(outputs)

        return outputs, attn

The attention computation itself is still the standard scaled dot-product attention: compute the attention coefficients from Q and K, then use them to weight V.

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = None, scale: bool = True):
        super(ScaledDotProductAttention, self).__init__()
        if dropout is not None:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = dropout
        self.softmax = nn.Softmax(dim=2)
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.permute(0, 2, 1))  # query-key overlap

        if self.scale:
            dimension = torch.as_tensor(k.size(-1), dtype=attn.dtype, device=attn.device).sqrt()
            attn = attn / dimension

        if mask is not None:
            attn = attn.masked_fill(mask, -1e9)
        attn = self.softmax(attn)

        if self.dropout is not None:
            attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn
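A minimal usage sketch (shapes are made up): in TFT the queries are only the future (decoder) positions, while keys and values span the whole window, so the attention weights tell you which time steps each forecast step looks at:

attn_layer = InterpretableMultiHeadAttention(n_head=4, d_model=16)
q = torch.randn(32, 6, 16)    # queries: the 6 prediction steps
kv = torch.randn(32, 30, 16)  # keys/values: all 30 time steps (past + future)
out, weights = attn_layer(q, kv, kv)
print(out.shape, weights.shape)  # torch.Size([32, 6, 16]) torch.Size([32, 6, 4, 30])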

Quantile regression

All it takes is swapping the loss function.

The quantile loss is defined as $\max\big(q \, (y - \hat{y}),\ (1 - q) \, (\hat{y} - y)\big)$.

def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # calculate quantile (pinball) loss for each quantile
    # (the quantiles would normally be set in __init__; hard-coded here for clarity)
    self.quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
    losses = []
    for i, q in enumerate(self.quantiles):
        errors = target - y_pred[..., i]
        losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
    losses = torch.cat(losses, dim=2)
    return losses
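To get a feel for why this yields calibrated intervals, here is a toy computation (the numbers are made up): the pinball loss penalizes under- and over-prediction asymmetrically, so the output trained with q = 0.9 is pushed up toward the 90th percentile of the target distribution.

target = torch.tensor([10.0])
for q in (0.1, 0.5, 0.9):
    for y_pred in (8.0, 12.0):
        errors = target - y_pred
        loss = torch.max((q - 1) * errors, q * errors)
        print(f"q={q}, y_pred={y_pred}: loss={loss.item():.2f}")
# for q=0.9, under-predicting (8.0) costs 1.80 while over-predicting (12.0) costs only 0.20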

Of course, using MSE instead is perfectly fine too.

Now that we're clear on what each module does, it's time to assemble the building blocks.

Putting the model together

To follow the model structure, let's walk through the forward pass.

  • First, the inputs need some preprocessing before they can be fed into the model (don't worry if the details aren't clear; just adapt this to your own data)

This corresponds to:

encoder_lengths = x["encoder_lengths"]
decoder_lengths = x["decoder_lengths"]
x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1)  # concatenate in time dimension
x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1)  # concatenate in time dimension
timesteps = x_cont.size(1)  # encode + decode length
max_encoder_length = int(encoder_lengths.max())
input_vectors = self.input_embeddings(x_cat)
input_vectors.update(
	{
		name: x_cont[..., idx].unsqueeze(-1)
		for idx, name in enumerate(self.hparams.x_reals)
		if name in self.reals
	}
)
  • Next comes variable selection (this requires some embedding conversions)

# Embedding and variable selection
if len(self.static_variables) > 0:
	# static embeddings will be constant over entire batch
	static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables}
	static_embedding, static_variable_selection = self.static_variable_selection(static_embedding)
else:
	static_embedding = torch.zeros(
		(x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype, device=self.device
	)
	static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype, device=self.device)

static_context_variable_selection = self.expand_static_context(
	self.static_context_variable_selection(static_embedding), timesteps
)

embeddings_varying_encoder = {
	name: input_vectors[name][:, :max_encoder_length] for name in self.encoder_variables
}
embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(
	embeddings_varying_encoder,
	static_context_variable_selection[:, :max_encoder_length],
)

embeddings_varying_decoder = {
	name: input_vectors[name][:, max_encoder_length:] for name in self.decoder_variables  # select decoder
}
embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(
	embeddings_varying_decoder,
	static_context_variable_selection[:, max_encoder_length:],
)
  • Then the LSTM encoding; note that the static information is used to initialize the LSTM's hidden and cell states

# LSTM
# calculate initial state
input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand(
	self.hparams.lstm_layers, -1, -1
)
input_cell = self.static_context_initial_cell_lstm(static_embedding).expand(self.hparams.lstm_layers, -1, -1)

# run local encoder
encoder_output, (hidden, cell) = self.lstm_encoder(
	embeddings_varying_encoder, (input_hidden, input_cell), lengths=encoder_lengths, enforce_sorted=False
)

# run local decoder
decoder_output, _ = self.lstm_decoder(
	embeddings_varying_decoder,
	(hidden, cell),
	lengths=decoder_lengths,
	enforce_sorted=False,
)
  • Residual connection (the red arrows in the figure)

# skip connection over lstm
lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output)
lstm_output_encoder = self.post_lstm_add_norm_encoder(lstm_output_encoder, embeddings_varying_encoder)

lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output)
lstm_output_decoder = self.post_lstm_add_norm_decoder(lstm_output_decoder, embeddings_varying_decoder)

lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1)
  • Before entering the attention block, the representation is enriched with static information

# static enrichment
static_context_enrichment = self.static_context_enrichment(static_embedding)
attn_input = self.static_enrichment(
	lstm_output, self.expand_static_context(static_context_enrichment, timesteps)
)
  • Then the multi-head attention is computed; it likewise has a residual connection and is followed by a GRN

# Attention
attn_output, attn_output_weights = self.multihead_attn(
	q=attn_input[:, max_encoder_length:],  # query only for predictions
	k=attn_input,
	v=attn_input,
	mask=self.get_attention_mask(
		encoder_lengths=encoder_lengths, decoder_length=timesteps - max_encoder_length
	),
)
# skip connection over attention
attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:])

output = self.pos_wise_ff(attn_output)
  • Finally, there is a longer-range residual connection, followed by a Dense output layer

output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:])
if self.n_targets > 1:  # if to use multi-target architecture
	output = [output_layer(output) for output_layer in self.output_layer]
else:
	output = self.output_layer(output)

And with that we have traced the model from input to output; if you look at the full architecture diagram again, it should now be much clearer!

Once you know what each part does, the module definitions themselves are not hard (just a bit long):

# processing inputs
# embeddings
self.input_embeddings = MultiEmbedding(
	embedding_sizes=self.hparams.embedding_sizes,
	categorical_groups=self.hparams.categorical_groups,
	embedding_paddings=self.hparams.embedding_paddings,
	x_categoricals=self.hparams.x_categoricals,
	max_embedding_size=self.hparams.hidden_size,
)

# continuous variable processing
self.prescalers = nn.ModuleDict(
	{
		name: nn.Linear(1, self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size))
		for name in self.reals
	}
)

# variable selection
# variable selection for static variables
static_input_sizes = {
	name: self.input_embeddings.output_size[name] for name in self.hparams.static_categoricals
}
static_input_sizes.update(
	{
		name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
		for name in self.hparams.static_reals
	}
)
self.static_variable_selection = VariableSelectionNetwork(
	input_sizes=static_input_sizes,
	hidden_size=self.hparams.hidden_size,
	input_embedding_flags={name: True for name in self.hparams.static_categoricals},
	dropout=self.hparams.dropout,
	prescalers=self.prescalers,
)

# variable selection for encoder and decoder
encoder_input_sizes = {
	name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder
}
encoder_input_sizes.update(
	{
		name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
		for name in self.hparams.time_varying_reals_encoder
	}
)

decoder_input_sizes = {
	name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder
}
decoder_input_sizes.update(
	{
		name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
		for name in self.hparams.time_varying_reals_decoder
	}
)

# create single variable grns that are shared across decoder and encoder
if self.hparams.share_single_variable_networks:
	self.shared_single_variable_grns = nn.ModuleDict()
	for name, input_size in encoder_input_sizes.items():
		self.shared_single_variable_grns[name] = GatedResidualNetwork(
			input_size,
			min(input_size, self.hparams.hidden_size),
			self.hparams.hidden_size,
			self.hparams.dropout,
		)
	for name, input_size in decoder_input_sizes.items():
		if name not in self.shared_single_variable_grns:
			self.shared_single_variable_grns[name] = GatedResidualNetwork(
				input_size,
				min(input_size, self.hparams.hidden_size),
				self.hparams.hidden_size,
				self.hparams.dropout,
			)

self.encoder_variable_selection = VariableSelectionNetwork(
	input_sizes=encoder_input_sizes,
	hidden_size=self.hparams.hidden_size,
	input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_encoder},
	dropout=self.hparams.dropout,
	context_size=self.hparams.hidden_size,
	prescalers=self.prescalers,
	single_variable_grns={}
	if not self.hparams.share_single_variable_networks
	else self.shared_single_variable_grns,
)

self.decoder_variable_selection = VariableSelectionNetwork(
	input_sizes=decoder_input_sizes,
	hidden_size=self.hparams.hidden_size,
	input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_decoder},
	dropout=self.hparams.dropout,
	context_size=self.hparams.hidden_size,
	prescalers=self.prescalers,
	single_variable_grns={}
	if not self.hparams.share_single_variable_networks
	else self.shared_single_variable_grns,
)

# static encoders
# for variable selection
self.static_context_variable_selection = GatedResidualNetwork(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	output_size=self.hparams.hidden_size,
	dropout=self.hparams.dropout,
)

# for hidden state of the lstm
self.static_context_initial_hidden_lstm = GatedResidualNetwork(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	output_size=self.hparams.hidden_size,
	dropout=self.hparams.dropout,
)

# for cell state of the lstm
self.static_context_initial_cell_lstm = GatedResidualNetwork(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	output_size=self.hparams.hidden_size,
	dropout=self.hparams.dropout,
)

# for post lstm static enrichment
self.static_context_enrichment = GatedResidualNetwork(
	self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.dropout
)

# lstm encoder (history) and decoder (future) for local processing
self.lstm_encoder = LSTM(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	num_layers=self.hparams.lstm_layers,
	dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
	batch_first=True,
)

self.lstm_decoder = LSTM(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	num_layers=self.hparams.lstm_layers,
	dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
	batch_first=True,
)

# skip connection for lstm
self.post_lstm_gate_encoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout)
self.post_lstm_gate_decoder = self.post_lstm_gate_encoder
# self.post_lstm_gate_decoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout)
self.post_lstm_add_norm_encoder = AddNorm(self.hparams.hidden_size, trainable_add=False)
# self.post_lstm_add_norm_decoder = AddNorm(self.hparams.hidden_size, trainable_add=True)
self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder

# static enrichment and processing past LSTM
self.static_enrichment = GatedResidualNetwork(
	input_size=self.hparams.hidden_size,
	hidden_size=self.hparams.hidden_size,
	output_size=self.hparams.hidden_size,
	dropout=self.hparams.dropout,
	context_size=self.hparams.hidden_size,
)

# attention for long-range processing
self.multihead_attn = InterpretableMultiHeadAttention(
	d_model=self.hparams.hidden_size, n_head=self.hparams.attention_head_size, dropout=self.hparams.dropout
)
self.post_attn_gate_norm = GateAddNorm(
	self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False
)
self.pos_wise_ff = GatedResidualNetwork(
	self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, dropout=self.hparams.dropout
)

# output processing -> no dropout at this late stage
self.pre_output_gate_norm = GateAddNorm(self.hparams.hidden_size, dropout=None, trainable_add=False)

if self.n_targets > 1:  # if to run with multiple targets
	self.output_layer = nn.ModuleList(
		[nn.Linear(self.hparams.hidden_size, output_size) for output_size in self.hparams.output_size]
	)
else:
	self.output_layer = nn.Linear(self.hparams.hidden_size, self.hparams.output_size)

Code source: pytorch_forecasting