文章目录
论文详情
Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting
出版情况:AAAI2021 Best Paper
Background
长时间序列预测要求模型能够更好地捕捉输入和输出之间的长距离依赖关系。
尽管LSTM是很好的工具(无法否认),但是就现有的情况,已经无法满足长序列的需求。MSE更是会随着预测长度的增加显著增加。
Transformer本身能力是很强,但是时间复杂度和空间复杂度会随着输入序列的增加而急剧变化。同时,Transformer在NLP上的显著成绩大多是在dozens of GPUs上得到的,这在真实世界中是无法承受的。这篇论文想要解决的问题同上一篇论文其实是相似的,降低Transformer的计算需求,并改进架构,提升效率,使其适应长时间序列预测任务的需求,同时保持准确率。
上面三点是论文中列出来的问题,计算时间复杂度、memory利用率和推理速度。
论文提出了三点内容:
- ProbSparse,在时间复杂度和空间复杂度上都有所降低
- self-attention distilling operation,帮助更好地处理长距离输入
- 改进了生成器,可一步输出结果
Method
模型的整体架构图,左边是编码器部分。绿色的是长序列输入;蓝色的部分代表蒸馏操作,可以降低模型规模。右边是解码器部分,对于要预测的部分,用0补全,生成的部分用橙色进行表示。
注意力机制改进
在量化实验中,注意力分数具有长尾效应,即小部分的点乘对在分数上占据了更大的比重:
在计算注意力得分的过程中,时间复杂度和空间复杂度都会随输入长度而变化增加。先前的研究在理论上受限于启发式规则的影响,针对multi-head self-attention也都采取相同的策略,这就限制了提升的可能。
根据论文Tsai et al. 2019的定义,
i
i
i-th query的注意力可以被视为概率的形式:
A
(
q
i
,
K
,
V
)
=
∑
j
k
(
q
i
,
k
j
)
∑
l
k
(
q
i
,
k
l
)
v
j
=
E
p
(
k
j
∣
q
i
)
[
v
j
]
\mathbb{A}(q_i,K,V)=\sum_j\frac{\mathscr{k}(q_i,k_j)}{\sum_l\mathscr{k}(q_i,k_l)}v_j=\mathbb{E}_{p(k_j|q_i)}[v_j]
A(qi,K,V)=j∑∑lk(qi,kl)k(qi,kj)vj=Ep(kj∣qi)[vj]
其中
p
(
k
i
∣
q
i
)
=
k
(
q
i
,
k
j
)
/
∑
l
k
(
q
i
,
k
l
)
\mathscr{p}(k_i|q_i)=\mathbb{k}(q_i,k_j)/\sum_l\mathbb{k}(q_i,k_l)
p(ki∣qi)=k(qi,kj)/∑lk(qi,kl)。
占优势的点积对就会促使相应的query的分布偏离均匀分布;相反,如果接近均匀分布 q ( k j ∣ q i ) = 1 / L K \mathscr{q}(k_j|q_i)=1/L_K q(kj∣qi)=1/LK,那对于整体而言的重要性并不高。
从直觉上,两个分布之间是能够辨别出哪些是重要的点积对,哪些是不重要的点积对,因此作者们采用KL散度,来评判
i
i
i-th query的离散程度:
K
L
(
q
∣
∣
p
)
≈
M
(
q
i
,
K
)
=
ln
∑
j
=
1
L
K
e
q
i
k
j
T
d
−
1
L
K
∑
j
=
1
L
K
q
i
k
j
T
d
KL(\mathscr{q}||\mathscr{p})≈M(q_i,K)=\ln{\sum^{L_K}_{j=1}e^{\frac{q_ik_j^T}{\sqrt{d}}}}-\frac{1}{L_K}\sum^{L_K}_{j=1}\frac{q_ik_j^T}{\sqrt{d}}
KL(q∣∣p)≈M(qi,K)=lnj=1∑LKedqikjT−LK1j=1∑LKdqikjT
公式的第一项是
q
i
q_i
qi对所有key值得Log-Sum-Exp,第二项则是算术平均值
M
(
q
i
,
K
)
M(q_i,K)
M(qi,K)越大,证明
p
\mathscr{p}
p就越偏离平均值,也就有更高得概率被包含在长尾分布中的头部。
但其实到这一步为止,计算的时间复杂度依然是 O ( L K L Q ) O(L_KL_Q) O(LKLQ),并没有降低。
之后的内容,其实理解并没有那么深刻,但是结合其它文章的分析(参考1,参考2)以及附录中的内容,大概能知道是在得到KL的上下界之后,确定KL的估计 max j { q i k j T d } − 1 L K ∑ j = 1 L K q i k j T d \max_j\{\frac{q_ik_j^T}{\sqrt{d}}\}-\frac{1}{L_K}\sum^{L_K}_{j=1}\frac{q_ik_j^T}{\sqrt{d}} maxj{dqikjT}−LK1∑j=1LKdqikjT,然后采取Top-u的策略。大致步骤如下:
输入: Q ∈ R m × d , K ∈ R n × d , V ∈ R m × d Q\in\mathbb{R}^{m\times d},K\in\mathbb{R}^{n\times d},V\in\mathbb{R}^{m\times d} Q∈Rm×d,K∈Rn×d,V∈Rm×d
超参: c , u = c ln m , U = m ln n c, u=c\ln{m},U=m\ln{n} c,u=clnm,U=mlnn
流程:
- 从 K K K中随机采样 U U U个点积对,构成集合 K ˉ \bar{K} Kˉ
- 计算采样分数集 S ˉ = Q K ˉ T \bar{S}=Q\bar{K}^T Sˉ=QKˉT
- 根据量化标准 M = max ( S ˉ ) − m e a n ( S ˉ ) M=\max{(\bar{S})}-mean(\bar{S}) M=max(Sˉ)−mean(Sˉ),对 S ˉ \bar{S} Sˉ进行计算
- 在M标准下,选择Top-u作为 Q ˉ \bar{Q} Qˉ
- S 1 = s o f t m a x ( Q ˉ K T / d ) ⋅ V , S 0 = m e a n ( V ) S_1=softmax(\bar{Q}K^T/\sqrt{d})\cdot V,S_0=mean(V) S1=softmax(QˉKT/d)⋅V,S0=mean(V),根据它们原来的行数,分别设置 S 1 , S 0 S_1,S_0 S1,S0
- 输出结果。
如上,整体的时间复杂度就能降至 O ( L K ln L Q ) O(L_K\ln{L_Q}) O(LKlnLQ)。
编码器结构
编码器上,增加了
M
a
x
P
o
o
l
MaxPool
MaxPool的操作,实现蒸馏。
为了实现鲁棒性和之后的维度对齐,作者们又增加了一个编码器的copy版本,输入1/2的长度,只经过一层attention block和蒸馏,得到1/4的长度输出。
但是编码器的结构这段,感觉依然不是很明白,大概查阅了下代码,也没有弄得很清楚,之后跑起来看看。
解码器结构
原文中提到的是使用了两层标准的Transformer架构:
上面的
X
t
o
k
e
n
t
X^t_{token}
Xtokent作为start token,而
X
0
t
X^t_0
X0t是待生成的内容。
作者在论文中举了个例子,比如要预测168个点的内容,那么
X
t
o
k
e
n
t
X^t_{token}
Xtokent就是已知的5天的数据,而
X
0
t
X^t_0
X0t则是需要预测的内容,
X
d
e
=
{
X
5
d
,
X
0
}
X_{de}=\{X_{5d},X_0\}
Xde={X5d,X0}。然后进行一次性预测(这里可能需要之后结合代码看,但是也可以参考其它的博文,我打算先看看代码)。
Experiment
Univariate Time-Series Forecasting
Informer是本文提出的方法,第二列的方法是没有使用sparse attention的方法。Informer在各个数据集上的表现都很好,但是和第二列的相比,似乎并没有很显著很显著的提升,从数据上看,基本上是伯仲之间。
Multivariate Time-Series Forecasting
Parameter Sensitivity
输入长度大小对结果的影响。在预测短序列时,增加输入长度,会降低想过,但是当预测长度持续增长,表现会有所提升。这是因为更长的输入会带来更多重复的短期pattern。对编码器而言,能够捕获到更多的依赖关系;对解码器而言,能够有更丰富的局部信息。
针对ProbSparse中采样因子的测试,实验证实了他们自己的假设,即self-attention中确实存在多余的点积对。
关于这一段的实验和编码器部分一样,等之后理解了再补充。
Ablation Study
下面一行的是移除蒸馏操作的模型。在实验效果上,和具有蒸馏操作的互有胜负,可是问题在于当输入长度超过720之后,就会出现OOM的现象,显存炸了。
该实验则证明了一步式生成的有效性。能更好地捕捉长距离之间的关系。
Conclusion
Informer感觉是一个很有意思的工作,感觉主要是自己菜,所以在这篇论文上花了很多的时间。内有大量的公式证明,发现还是自己的基础不够扎实,能够囫囵地知道为什么这么做,但是再具体一点儿,就分析不下去了。实验部分也很充分。作者在开头的时候提出的三个问题:1、self-attention的计算复杂度过高;2、存储瓶颈;3、推理速度过慢。在文中也都给出了有效的解决方法:1、引入sparse attention;2、引入蒸馏操作;3、Generative style decoder,一步得到输出结果。
LogSparse那篇论文也很有意思,在他们自己的实验中,其实效果提升并不明显(提升看起来不是特别的多),但是在这篇论文中,能明显感觉到LogSparse其实是一个绕不开的baseline。
Informer无疑很强,是一个值得花时间之后再继续琢磨的工作。希望之后有能力再好好理解这个工作。