Today I'd like to share a time-series paper from Google:
Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting.
First, let's consider what a good time-series model might need:
- Handles both univariate and multivariate time series
- Uses not only the series' own history but also information from other covariates (i.e., not just autoregression)
- Exploits static information in addition to dynamic information, e.g., some statistical features
- Adapts to time series with both complex and simple patterns
- Direct multi-step forecasting is clearly better than recursing one-step predictions (it avoids error accumulation)
- Goes beyond point forecasts to provide uncertainty intervals for the predictions
- Explainability is all we need!
Model architecture at a glance
Take a breath and don't let the figure scare you; it really isn't hard.
Let's take a rough bottom-up pass over what this model does:
- Level 1 Input layer
Three parts: static information, historical information (the target variable), and known future information (other variables)
- Level 2 Variable selection layer
In plain terms, this does feature selection
- Level 3 LSTM encoding layer
Since this is a time series, using an LSTM to capture short- and long-term information makes perfect sense
- Level 4 Gate + Add&Norm
The gate can be read as further weighing the importance of different features; the residual connection and normalization are standard practice
- Level 5 GRN
Essentially the same idea as Level 4; think of it as deepening the network
- Level 6 Attention
Weights the information from different time steps
- Level 7 Output layer
Performs quantile regression, which yields prediction intervals
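Before dissecting the modules, the seven levels can be sketched as a schematic pipeline. This is a pure-Python sketch with identity-like stubs; every function name and number below is illustrative only, not the real layers:

```python
# Schematic TFT dataflow; shapes and ordering only, no learning involved.
def variable_selection(inputs):
    # Level 2 stub: collapse each step's variables with uniform weights
    return [sum(v) / len(v) for v in inputs]

def lstm_encode(seq):
    # Level 3 stub: sequence encoding (identity here)
    return seq

def gate_add_norm(x, skip):
    # Level 4 stub: gate + residual connection (+ norm, omitted)
    return [a + b for a, b in zip(x, skip)]

def attention(seq):
    # Level 6 stub: uniform weighting over time steps
    w = 1.0 / len(seq)
    return [w * sum(seq)] * len(seq)

def quantile_head(x, quantiles):
    # Level 7 stub: one output per quantile
    return {q: x for q in quantiles}

history = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]  # 3 time steps, 2 variables
selected = variable_selection(history)
encoded = gate_add_norm(lstm_encode(selected), selected)
context = attention(encoded)
forecast = quantile_head(context[-1], quantiles=(0.1, 0.5, 0.9))
print(forecast)
```

The point is only the order of operations: select variables, encode the sequence, gate with a residual path, attend over time, then emit one value per quantile.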
With this rough sense of the input-to-output flow, let's dissect the key modules of the network!
Since variable selection itself uses the GRN module introduced later, we start with the GRN.
Gated Residual Network (GRN)
The only parts here that might be unfamiliar are ELU (Exponential Linear Unit) and GLU (Gated Linear Unit).
- ELU is an activation function that can take negative values. Compared with ReLU, this pushes the mean activation closer to zero, giving a Batch-Normalization-like effect at much lower computational cost. It also saturates softly for strongly negative inputs, which improves robustness to noise.
- GLU is essentially a gated weighting:
h(X) = (XW + b) ⊗ σ(XV + c)
Given input X, the parameters W, b, V, c are learned. Intuitively, the input first goes through an affine transformation and is then weighted element-wise; the weight is another affine transformation of the input squashed to (0, 1) by a sigmoid.
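To make the two activations concrete, here is a dependency-free scalar sketch; the weights w, b, v, c below are made-up numbers, not learned parameters:

```python
import math

def elu(x, alpha=1.0):
    # negative inputs saturate softly toward -alpha instead of being zeroed
    return x if x > 0 else alpha * (math.exp(x) - 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def glu_scalar(x, w, b, v, c):
    # affine transform of x, gated by a sigmoid of another affine transform
    return (w * x + b) * sigmoid(v * x + c)

print(elu(1.0))  # positive inputs pass through unchanged -> 1.0
print(round(elu(-2.0), 4))  # soft saturation below zero
print(glu_scalar(1.0, w=2.0, b=0.0, v=0.0, c=0.0))  # gate = sigmoid(0) = 0.5 -> 1.0
```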
class GatedLinearUnit(nn.Module):
"""Gated Linear Unit"""
def __init__(self, input_size: int, hidden_size: int = None):
super().__init__()
self.hidden_size = hidden_size or input_size
self.fc = nn.Linear(input_size, self.hidden_size * 2)
def forward(self, x):
x = self.fc(x)
x = F.glu(x, dim=-1)
return x
- Residual connection and normalization
class AddNorm(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()
        self.input_size = input_size
        self.norm = nn.LayerNorm(self.input_size)
def forward(self, x: torch.Tensor, skip: torch.Tensor):
output = self.norm(x + skip)
return output
class GateAddNorm(nn.Module):
def __init__(self,input_size: int, hidden_size: int = None):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size or input_size
self.glu = GatedLinearUnit(self.input_size, hidden_size=self.hidden_size)
self.add_norm = AddNorm(self.hidden_size)
def forward(self, x, skip):
output = self.glu(x)
output = self.add_norm(output, skip)
return output
Putting the whole module together:
class GatedResidualNetwork(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: int,
context_size: int = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.context_size = context_size
self.hidden_size = hidden_size
self.fc1 = nn.Linear(self.input_size, self.hidden_size)
self.elu = nn.ELU()
if self.context_size is not None:
self.context = nn.Linear(self.context_size, self.hidden_size, bias=False)
self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
self.gate_norm = GateAddNorm(
input_size=self.hidden_size,
hidden_size=self.output_size
)
    def forward(self, x, context=None, residual=None):
        if residual is None:
            # default the skip connection to the module input
            residual = x
        x = self.fc1(x)
if context is not None:
context = self.context(context)
x = x + context
x = self.elu(x)
x = self.fc2(x)
x = self.gate_norm(x, residual)
return x
Now that we understand the GRN building block, let's see how variable selection works.
Variable Selection Network (VSN)
Variable selection here is a soft selection: instead of discarding unimportant variables, it weights them, with larger weights indicating higher importance. (Think of it as attention with a different way of computing the attention coefficients.)
Here we consider the multivariate case only, not the univariate special case.
class VariableSelectionNetwork(nn.Module):
def __init__(
self,
input_sizes: Dict[str, int],
hidden_size: int,
input_embedding_flags: Dict[str, bool] = {},
context_size: int = None,
single_variable_grns: Dict[str, GatedResidualNetwork] = {},
prescalers: Dict[str, nn.Linear] = {},
):
"""
        Calculate weights for ``num_inputs`` variables which are each of size ``input_size``
"""
super().__init__()
self.hidden_size = hidden_size
self.input_sizes = input_sizes
self.input_embedding_flags = input_embedding_flags
        self.context_size = context_size
        # derived sizes used by the flattened GRN below
        self.num_inputs = len(self.input_sizes)
        self.input_size_total = sum(self.input_sizes.values())
if self.context_size is not None:
self.flattened_grn = GatedResidualNetwork(
self.input_size_total,
min(self.hidden_size, self.num_inputs),
self.num_inputs,
self.context_size,
)
else:
self.flattened_grn = GatedResidualNetwork(
self.input_size_total,
min(self.hidden_size, self.num_inputs),
self.num_inputs,
)
self.single_variable_grns = nn.ModuleDict()
self.prescalers = nn.ModuleDict()
for name, input_size in self.input_sizes.items():
if name in single_variable_grns:
self.single_variable_grns[name] = single_variable_grns[name]
elif self.input_embedding_flags.get(name, False):
self.single_variable_grns[name] = ResampleNorm(input_size, self.hidden_size)
else:
self.single_variable_grns[name] = GatedResidualNetwork(
input_size,
min(input_size, self.hidden_size),
output_size=self.hidden_size,
)
if name in prescalers: # reals need to be first scaled up
self.prescalers[name] = prescalers[name]
elif not self.input_embedding_flags.get(name, False):
self.prescalers[name] = nn.Linear(1, input_size)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
# transform single variables
var_outputs = []
weight_inputs = []
for name in self.input_sizes.keys():
# select embedding belonging to a single input
variable_embedding = x[name]
if name in self.prescalers:
variable_embedding = self.prescalers[name](variable_embedding)
weight_inputs.append(variable_embedding)
var_outputs.append(self.single_variable_grns[name](variable_embedding))
var_outputs = torch.stack(var_outputs, dim=-1)
        # compute the variable selection weights
flat_embedding = torch.cat(weight_inputs, dim=-1)
sparse_weights = self.flattened_grn(flat_embedding, context)
sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)
        outputs = var_outputs * sparse_weights  # weighted sum
outputs = outputs.sum(dim=-1)
return outputs, sparse_weights
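The core of the soft selection above — softmax the per-variable scores, then take a weighted sum of the per-variable representations — can be checked by hand. The scores and values below are made up; in the real model the scores come from the flattened GRN:

```python
import math

def softmax(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# one score per variable
scores = [2.0, 0.0, -2.0]
weights = softmax(scores)

# one transformed representation per variable (scalars here for simplicity)
var_outputs = [10.0, 5.0, 1.0]
selected = sum(w * v for w, v in zip(weights, var_outputs))

print([round(w, 3) for w in weights])  # weights sum to 1; variable 0 dominates
print(round(selected, 3))
```

No variable is dropped outright; a low-scoring variable simply contributes very little to the sum, and the weights themselves are what the paper reads off for interpretability.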
- LSTM encoder
Here we can simply use PyTorch's built-in module:
self.lstm_encoder = LSTM(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
num_layers=self.hparams.lstm_layers,
dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
batch_first=True,
)
Next comes the design of the attention mechanism.
- Attention
The paper's change to multi-head attention is to share part of the parameters: each head has its own linear projections for Q and K, but the projection for V is shared across all heads.
class InterpretableMultiHeadAttention(nn.Module):
def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
super(InterpretableMultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_k = self.d_q = self.d_v = d_model // n_head
self.dropout = nn.Dropout(p=dropout)
self.v_layer = nn.Linear(self.d_model, self.d_v)
self.q_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_q) for _ in range(self.n_head)])
self.k_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_k) for _ in range(self.n_head)])
self.attention = ScaledDotProductAttention()
self.w_h = nn.Linear(self.d_v, self.d_model, bias=False)
def forward(self, q, k, v, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
heads = []
attns = []
        vs = self.v_layer(v)  # shared across all heads
for i in range(self.n_head):
qs = self.q_layers[i](q)
ks = self.k_layers[i](k)
head, attn = self.attention(qs, ks, vs, mask)
head_dropout = self.dropout(head)
heads.append(head_dropout)
attns.append(attn)
head = torch.stack(heads, dim=2) if self.n_head > 1 else heads[0]
attn = torch.stack(attns, dim=2)
outputs = torch.mean(head, dim=2) if self.n_head > 1 else head
outputs = self.w_h(outputs)
outputs = self.dropout(outputs)
return outputs, attn
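The interpretability payoff of sharing V is in the forward pass above: the per-head outputs (and attention matrices) are averaged, so a single attention pattern summarizes all heads and can be read as time-step importances. A minimal sketch with made-up 2×2 attention matrices:

```python
# per-head attention matrices (rows: queries, cols: keys); made-up numbers
head_attns = [
    [[0.9, 0.1],
     [0.2, 0.8]],
    [[0.5, 0.5],
     [0.4, 0.6]],
]

n_head = len(head_attns)
avg = [
    [sum(h[i][j] for h in head_attns) / n_head for j in range(2)]
    for i in range(2)
]
# each row of the averaged matrix still sums to 1, so it still reads as
# a distribution of importance over time steps
print(avg)
```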
The attention computation itself is still the standard ScaledDotProductAttention: compute attention coefficients from Q and K, then use them to weight V.
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout: float = None, scale: bool = True):
super(ScaledDotProductAttention, self).__init__()
if dropout is not None:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = dropout
self.softmax = nn.Softmax(dim=2)
self.scale = scale
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.permute(0, 2, 1)) # query-key overlap
if self.scale:
dimension = torch.as_tensor(k.size(-1), dtype=attn.dtype, device=attn.device).sqrt()
attn = attn / dimension
if mask is not None:
attn = attn.masked_fill(mask, -1e9)
attn = self.softmax(attn)
if self.dropout is not None:
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
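The computation in ScaledDotProductAttention can be verified with a tiny pure-Python example (one query, two keys, d_k = 1; all numbers are arbitrary):

```python
import math

def scaled_dot_attention(q, k, v):
    d_k = len(k[0])
    # attention scores: QK^T / sqrt(d_k)
    scores = [[sum(qi * ki for qi, ki in zip(qr, kr)) / math.sqrt(d_k)
               for kr in k] for qr in q]
    # row-wise softmax over the keys
    attn = []
    for row in scores:
        exps = [math.exp(s) for s in row]
        total = sum(exps)
        attn.append([e / total for e in exps])
    # weighted sum of the values
    out = [[sum(a * vr[j] for a, vr in zip(arow, v))
            for j in range(len(v[0]))] for arow in attn]
    return out, attn

q = [[1.0]]
k = [[1.0], [0.0]]
v = [[10.0], [0.0]]
out, attn = scaled_dot_attention(q, k, v)
print([round(a, 3) for a in attn[0]])  # more weight on the matching key
print(round(out[0][0], 3))
```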
Quantile regression
Just swap in a different loss function.
The quantile (pinball) loss is defined as QL(y, ŷ, q) = max(q(y − ŷ), (1 − q)(ŷ − y))
def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# calculate quantile loss
self.quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
losses = []
for i, q in enumerate(self.quantiles):
errors = target - y_pred[..., i]
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        losses = torch.cat(losses, dim=2)
        return losses
Using MSE instead is of course perfectly fine.
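The pinball loss itself can be sanity-checked in plain Python (the values below are arbitrary):

```python
def quantile_loss(y, y_hat, q):
    # penalize under-prediction by q and over-prediction by (1 - q)
    return max(q * (y - y_hat), (1 - q) * (y_hat - y))

# with q = 0.9, under-predicting costs 9x more than over-predicting,
# which pushes the 0.9-quantile forecast above most observations
print(quantile_loss(10.0, 8.0, 0.9))   # under-prediction: 0.9 * 2 = 1.8
print(round(quantile_loss(10.0, 12.0, 0.9), 4))  # over-prediction: 0.1 * 2 = 0.2
```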
Once the role of each module is clear, it's time to assemble the building blocks.
Model assembly
To follow the model structure, let's walk through forward from the top.
- First, the inputs need some preprocessing before they can enter the model (don't worry if this exact preprocessing doesn't match your setup; adapt it to your own data)
The corresponding code is:
encoder_lengths = x["encoder_lengths"]
decoder_lengths = x["decoder_lengths"]
x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension
x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension
timesteps = x_cont.size(1) # encode + decode length
max_encoder_length = int(encoder_lengths.max())
input_vectors = self.input_embeddings(x_cat)
input_vectors.update(
{
name: x_cont[..., idx].unsqueeze(-1)
for idx, name in enumerate(self.hparams.x_reals)
if name in self.reals
}
)
- Next comes variable selection (which requires some embedding conversions)
# Embedding and variable selection
if len(self.static_variables) > 0:
# static embeddings will be constant over entire batch
static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables}
static_embedding, static_variable_selection = self.static_variable_selection(static_embedding)
else:
static_embedding = torch.zeros(
(x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype, device=self.device
)
static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype, device=self.device)
static_context_variable_selection = self.expand_static_context(
self.static_context_variable_selection(static_embedding), timesteps
)
embeddings_varying_encoder = {
name: input_vectors[name][:, :max_encoder_length] for name in self.encoder_variables
}
embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(
embeddings_varying_encoder,
static_context_variable_selection[:, :max_encoder_length],
)
embeddings_varying_decoder = {
name: input_vectors[name][:, max_encoder_length:] for name in self.decoder_variables # select decoder
}
embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(
embeddings_varying_decoder,
static_context_variable_selection[:, max_encoder_length:],
)
- Then the LSTM encoding; note that static information is used to initialize the LSTM states
# LSTM
# calculate initial state
input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand(
self.hparams.lstm_layers, -1, -1
)
input_cell = self.static_context_initial_cell_lstm(static_embedding).expand(self.hparams.lstm_layers, -1, -1)
# run local encoder
encoder_output, (hidden, cell) = self.lstm_encoder(
embeddings_varying_encoder, (input_hidden, input_cell), lengths=encoder_lengths, enforce_sorted=False
)
# run local decoder
decoder_output, _ = self.lstm_decoder(
embeddings_varying_decoder,
(hidden, cell),
lengths=decoder_lengths,
enforce_sorted=False,
)
- Residual connection (the red arrows in the figure)
# skip connection over lstm
lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output)
lstm_output_encoder = self.post_lstm_add_norm_encoder(lstm_output_encoder, embeddings_varying_encoder)
lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output)
lstm_output_decoder = self.post_lstm_add_norm_decoder(lstm_output_decoder, embeddings_varying_decoder)
lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1)
- Before entering attention, enrich the sequence with static information
# static enrichment
static_context_enrichment = self.static_context_enrichment(static_embedding)
attn_input = self.static_enrichment(
lstm_output, self.expand_static_context(static_context_enrichment, timesteps)
)
- Then compute multi-head attention; it also has a residual connection, followed by a GRN before output
# Attention
attn_output, attn_output_weights = self.multihead_attn(
q=attn_input[:, max_encoder_length:], # query only for predictions
k=attn_input,
v=attn_input,
mask=self.get_attention_mask(
encoder_lengths=encoder_lengths, decoder_length=timesteps - max_encoder_length
),
)
# skip connection over attention
attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:])
output = self.pos_wise_ff(attn_output)
- Finally there is a longer-range residual connection, and a dense layer produces the output
output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:])
if self.n_targets > 1: # if to use multi-target architecture
output = [output_layer(output) for output_layer in self.output_layer]
else:
output = self.output_layer(output)
That's the full path from input to output; the overall architecture diagram should look much clearer now!
Knowing what each part does, the module definitions themselves are not hard (just long):
# processing inputs
# embeddings
self.input_embeddings = MultiEmbedding(
embedding_sizes=self.hparams.embedding_sizes,
categorical_groups=self.hparams.categorical_groups,
embedding_paddings=self.hparams.embedding_paddings,
x_categoricals=self.hparams.x_categoricals,
max_embedding_size=self.hparams.hidden_size,
)
# continuous variable processing
self.prescalers = nn.ModuleDict(
{
name: nn.Linear(1, self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size))
for name in self.reals
}
)
# variable selection
# variable selection for static variables
static_input_sizes = {
name: self.input_embeddings.output_size[name] for name in self.hparams.static_categoricals
}
static_input_sizes.update(
{
name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
for name in self.hparams.static_reals
}
)
self.static_variable_selection = VariableSelectionNetwork(
input_sizes=static_input_sizes,
hidden_size=self.hparams.hidden_size,
input_embedding_flags={name: True for name in self.hparams.static_categoricals},
dropout=self.hparams.dropout,
prescalers=self.prescalers,
)
# variable selection for encoder and decoder
encoder_input_sizes = {
name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder
}
encoder_input_sizes.update(
{
name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
for name in self.hparams.time_varying_reals_encoder
}
)
decoder_input_sizes = {
name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder
}
decoder_input_sizes.update(
{
name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)
for name in self.hparams.time_varying_reals_decoder
}
)
# create single variable grns that are shared across decoder and encoder
if self.hparams.share_single_variable_networks:
self.shared_single_variable_grns = nn.ModuleDict()
for name, input_size in encoder_input_sizes.items():
self.shared_single_variable_grns[name] = GatedResidualNetwork(
input_size,
min(input_size, self.hparams.hidden_size),
self.hparams.hidden_size,
self.hparams.dropout,
)
for name, input_size in decoder_input_sizes.items():
if name not in self.shared_single_variable_grns:
self.shared_single_variable_grns[name] = GatedResidualNetwork(
input_size,
min(input_size, self.hparams.hidden_size),
self.hparams.hidden_size,
self.hparams.dropout,
)
self.encoder_variable_selection = VariableSelectionNetwork(
input_sizes=encoder_input_sizes,
hidden_size=self.hparams.hidden_size,
input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_encoder},
dropout=self.hparams.dropout,
context_size=self.hparams.hidden_size,
prescalers=self.prescalers,
single_variable_grns={}
if not self.hparams.share_single_variable_networks
else self.shared_single_variable_grns,
)
self.decoder_variable_selection = VariableSelectionNetwork(
input_sizes=decoder_input_sizes,
hidden_size=self.hparams.hidden_size,
input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_decoder},
dropout=self.hparams.dropout,
context_size=self.hparams.hidden_size,
prescalers=self.prescalers,
single_variable_grns={}
if not self.hparams.share_single_variable_networks
else self.shared_single_variable_grns,
)
# static encoders
# for variable selection
self.static_context_variable_selection = GatedResidualNetwork(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
output_size=self.hparams.hidden_size,
dropout=self.hparams.dropout,
)
# for hidden state of the lstm
self.static_context_initial_hidden_lstm = GatedResidualNetwork(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
output_size=self.hparams.hidden_size,
dropout=self.hparams.dropout,
)
# for cell state of the lstm
self.static_context_initial_cell_lstm = GatedResidualNetwork(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
output_size=self.hparams.hidden_size,
dropout=self.hparams.dropout,
)
# for post lstm static enrichment
self.static_context_enrichment = GatedResidualNetwork(
self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.dropout
)
# lstm encoder (history) and decoder (future) for local processing
self.lstm_encoder = LSTM(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
num_layers=self.hparams.lstm_layers,
dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
batch_first=True,
)
self.lstm_decoder = LSTM(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
num_layers=self.hparams.lstm_layers,
dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0,
batch_first=True,
)
# skip connection for lstm
self.post_lstm_gate_encoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout)
self.post_lstm_gate_decoder = self.post_lstm_gate_encoder
# self.post_lstm_gate_decoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout)
self.post_lstm_add_norm_encoder = AddNorm(self.hparams.hidden_size, trainable_add=False)
# self.post_lstm_add_norm_decoder = AddNorm(self.hparams.hidden_size, trainable_add=True)
self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder
# static enrichment and processing past LSTM
self.static_enrichment = GatedResidualNetwork(
input_size=self.hparams.hidden_size,
hidden_size=self.hparams.hidden_size,
output_size=self.hparams.hidden_size,
dropout=self.hparams.dropout,
context_size=self.hparams.hidden_size,
)
# attention for long-range processing
self.multihead_attn = InterpretableMultiHeadAttention(
d_model=self.hparams.hidden_size, n_head=self.hparams.attention_head_size, dropout=self.hparams.dropout
)
self.post_attn_gate_norm = GateAddNorm(
self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False
)
self.pos_wise_ff = GatedResidualNetwork(
self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, dropout=self.hparams.dropout
)
# output processing -> no dropout at this late stage
self.pre_output_gate_norm = GateAddNorm(self.hparams.hidden_size, dropout=None, trainable_add=False)
if self.n_targets > 1: # if to run with multiple targets
self.output_layer = nn.ModuleList(
[nn.Linear(self.hparams.hidden_size, output_size) for output_size in self.hparams.output_size]
)
else:
self.output_layer = nn.Linear(self.hparams.hidden_size, self.hparams.output_size)
Code source: pytorch_forecasting