Content
此篇论文主要完成了:
1.通过数学推导,找到并证明了限制RNN-based LMs的性能瓶颈之一——Softmax Bottleneck问题
2.针对这个瓶颈,提出了一个解决方案—— Mixture of Softmaxes
Introduction
Language Modeling Problem
LM问题在指已知了一个符号(token)序列:
X
=
(
X
1
,
.
.
.
,
X
T
)
X=(X_1,...,X_T)
X=(X1,...,XT)
的情况下,生成一个Language模型来模拟这个序列出现的概率,即求解P(X)的值:
P
(
X
)
=
∏
t
P
(
X
t
∣
X
<
T
)
P(X)=\prod_tP(X_t|X_{<T})
P(X)=t∏P(Xt∣X<T)
而根据链式法则(chain rule)和马尔科夫假设(Markov Assumption), P(X)的值可以通过求解它对应的联合概率得出:
P
(
X
)
=
∏
t
P
(
X
t
∣
X
<
T
)
=
∏
P
(
X
t
∣
C
t
)
P(X)=\prod_tP(X_t|X_{<T})=\prod{P(X_t|C_t)}
P(X)=t∏P(Xt∣X<T)=∏P(Xt∣Ct)
X
t
X_t
Xt: 下一个符号的概率分布
C
t
C_t
Ct:历史序列 / 已经出现的所有Tokens
因此,原始的LM问题就转换成了:在每个时刻t,根据已知的符号序列 C t C_t Ct(History), 求解下一时刻可能输出的符号的概率。 由于输出符号有很多种而可能性,所以这个概率的其实是一个在Vocabulary(或Token Set)上的概率分布。
Standard Approach: RNN based LMs
由于符号序列自带时间属性,而我们需要模拟的也是符号之间的时间依赖关系。因此对于LM问题来说,最标准的,而且state of art 模型都是基于RNN的。以下为一个基于RNN的Language Model的结构简图:
其中:
h
t
=
σ
(
V
h
t
−
1
+
U
x
t
)
h_t=\sigma(Vh_{t-1}+Ux_t)
ht=σ(Vht−1+Uxt)
o
=
W
h
t
o=Wh_t
o=Wht
P
(
x
t
∣
c
t
)
=
y
t
=
e
o
(
x
t
)
∑
v
e
o
(
v
)
P(x_t|c_t)=y_t=\frac{e^{o(x_t)}}{\sum_ve^{o(v)}}
P(xt∣ct)=yt=∑veo(v)eo(xt)
首先图片左下角的符号序列 " the cat sat on the " 是我们在此时刻 t 已知的历史序列,即
C
t
C_t
Ct 。由于输入的每个Token是由one-hot编码表示的,当Vocabulary很大的情况下,这个输入维度会非常的高。因此在处理这种高维输入时,会先使用word embedding matrix ( 图中矩阵U ) 来降低维度&学习词语的内部联系,使输入更有意义。
之后,经过处理的输入会被传送给RNN。基于此时刻 t 的输入和上一时刻RNN的旧隐藏状态 $ h_{t-1} $ , RNN会产生新的隐藏状态 $ h_t $ 。
此隐藏状态可能看作是一个由RNN学习到的,含有下一时刻输出符号的信息的特征。由于输出是一个基于Vocabulary的概率分布,因此我们必须把学习到的这个特征映射回初始的Vocabulary。上图中,
h
t
h_t
ht下的output embedding matrix (W) 就负责这个反映射。应用上,两个embedding 矩阵 U和W是一样的。
###Hypothesis&Main Issues
由Anton Maximilian Schäfer and Hans Georg Zimmermann写的Recurrent Neural Networks Are Universal Approximators论文可以得知,RNN的表达力是很强的,它可以模拟逼近任意的非线性动态系统(Universal approximation theorem)。由此作者推测出,基于RNN的LMs的性能瓶颈之一应该是RNN最后使用点乘+softmax操作,即:$o=Wh_t ; out = softmax(o) $。
Mathematical Analysis of LM
Defination
为了能进行数学推导和定量分析来证明这个假设,首先我们需要一个自然语言的数学表达。自然语言L可以表示成N个元组的集合:
L
=
{
(
c
1
,
P
∗
(
X
∣
c
1
)
)
,
.
.
.
,
(
c
N
,
P
∗
(
X
∣
c
N
)
)
}
L=\{(c_1,P^*(X|c_1)),...,(c_N,P^*(X|c_N))\}
L={(c1,P∗(X∣c1)),...,(cN,P∗(X∣cN))}
其中:
c
i
:
c_i:
ci:代表了语言中的任一个可能的context(history token序列)
P
∗
(
X
∣
c
i
)
:
P^*(X|c_i):
P∗(X∣ci):真实的数据分布,即:已知一个历史符号序列(
c
i
c_i
ci),下一符号在Token集合
X
X
X上的概率分布
X
=
{
x
1
,
x
2
,
.
.
.
,
x
M
}
:
X=\{x_1,x_2,...,x_M\}:
X={x1,x2,...,xM}: 代表了语言L中所有可能出现的符号
N
:
N:
N: 所有可能的上下文(符号组合)的数目
至此,LM问题可以转换成如下的数学公式表达:
P
θ
(
X
∣
c
)
=
P
∗
(
X
∣
c
)
P_\theta(X|c)=P^*(X|c)
Pθ(X∣c)=P∗(X∣c)
即,给定一个自然语言L,LM需要学习一组参数
θ
\theta
θ,基于此组参数的模型可以逼近真实的任一上下文(context)所对应的下一符号概率分布。
若我们使用RNN-based LMs, 那么在network的输出端,我们能从softmax layer 的输出直接得到基于此时刻 t 的下一符号概率分布
P
θ
(
X
∣
c
)
P_{\theta}(X|c)
Pθ(X∣c) :
P
θ
(
X
∣
c
)
=
e
x
p
(
h
c
T
w
x
)
∑
x
e
x
p
(
h
c
T
w
x
)
P_\theta(X|c)=\frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}
Pθ(X∣c)=∑xexp(hcTwx)exp(hcTwx)
因此,训练模型的Objective可以用以下等式表达:
P
θ
(
X
∣
c
)
=
e
x
p
(
h
c
T
w
x
)
∑
x
e
x
p
(
h
c
T
w
x
)
=
P
∗
(
X
∣
c
)
P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
Pθ(X∣c)=∑xexp(hcTwx)exp(hcTwx)=P∗(X∣c)
即,我们使用一个RNN-based LM 来模拟每个可能context下的下一符号概率分布,并且不断优化模型使用的参数
θ
\theta
θ,使LM输出的概率分布逼近真实分布。
Matrix Factorization Problem
在数学化表达LM问题后,它的Objective公式还可以通过矩阵分解来做进一步的分析。
在
P
θ
(
X
∣
c
)
P_{\theta}(X|c)
Pθ(X∣c)的表达式中,
h
c
T
h_c^T
hcT代表了输入是不同的context(历史序列)的情况下,RNN所对应的不同隐藏状态。此处,可以把所有可能的情况列出,排列组合成一个矩阵:
H
θ
=
[
h
c
1
T
h
c
2
T
.
.
.
h
c
N
T
]
H_{\theta}=\left[ \begin{matrix} h^T_{c_1} \\ h^T_{c_2} \\ ... \\ h^T_{c_N} \end{matrix} \right]
Hθ=⎣⎢⎢⎡hc1Thc2T...hcNT⎦⎥⎥⎤
这个矩阵包含了RNN针对不同的Context的所有可能的隐藏状态。根据此节开篇的假设,自然语言L一共有
N
N
N种可能的context(即:符号组合序列)。
相似地,公式中的
w
x
w_x
wx也可以统一成矩阵表达:
W
θ
=
[
w
x
1
T
w
x
2
T
.
.
.
w
x
M
T
]
W_{\theta}= \left[ \begin{matrix} w^T_{x_1} \\ w^T_{x_2} \\ ... \\ w^T_{x_M} \end{matrix} \right]
Wθ=⎣⎢⎢⎡wx1Twx2T...wxMT⎦⎥⎥⎤
W
θ
W_{\theta}
Wθ中的每一行代表了语言L中的某一个符号
x
i
x_i
xi所对应的embedding coefficient,用以把RNN学到的隐藏状态映射回包含
X
X
X符号集的Vocabulary空间。同样,根据此节开篇的假设,自然语言L一共有
M
M
M种可能的符号(tokens)。
最后,我们还需要把自然语言L真实的条件概率分布(在各种可能的context下,下一符号的概率分布)用矩阵的方式表达,从而能使用矩阵知识,数学地分析RNN-based LMs。此处假设矩阵
A
A
A代表了真实条件概率分布
P
∗
(
X
∣
c
)
P^*(X|c)
P∗(X∣c)取
log
\log
log后的结果:
A
=
[
log
P
∗
(
x
1
∣
c
1
)
log
P
∗
(
x
2
∣
c
1
)
.
.
.
log
P
∗
(
x
M
∣
c
1
)
log
P
∗
(
x
1
∣
c
2
)
log
P
∗
(
x
2
∣
c
2
)
.
.
.
log
P
∗
(
x
M
∣
c
2
)
.
.
.
.
.
.
.
.
.
.
.
.
log
P
∗
(
x
1
∣
c
N
)
log
P
∗
(
x
2
∣
c
N
)
.
.
.
log
P
∗
(
x
M
∣
c
N
)
]
A= \left[ \begin{matrix} \log{P^*(x_1|c_1)} &\log{P^*(x_2|c_1)}&...&\log{P^*(x_M|c_1)}\\ \log{P^*(x_1|c_2)} &\log{P^*(x_2|c_2)}&...&\log{P^*(x_M|c_2)} \\ ...&...&...&... \\ \log{P^*(x_1|c_N)} &\log{P^*(x_2|c_N)}&...&\log{P^*(x_M|c_N)} \end{matrix} \right]
A=⎣⎢⎢⎡logP∗(x1∣c1)logP∗(x1∣c2)...logP∗(x1∣cN)logP∗(x2∣c1)logP∗(x2∣c2)...logP∗(x2∣cN)............logP∗(xM∣c1)logP∗(xM∣c2)...logP∗(xM∣cN)⎦⎥⎥⎤
由上公式可知,
A
A
A包含了context与对应next token的所有可能的组合。
Rank Analysis
在经历了上述对Objective的分析及矩阵转换,RNN-based LM问题事实上可以抽象如下:
∃
θ
,
log
(
S
o
f
t
m
a
x
(
H
θ
W
θ
T
)
)
=
A
\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
∃θ,log(Softmax(HθWθT))=A
即,通过学习,我们希望找到一组参数
θ
\theta
θ,以它为参数的LM模型(即RNN)可以逼近真实的下一符号概率分布的
log
\log
log。
为了能推导出Softmax存在的瓶颈,首先先要引入一个矩阵操作
r
o
w
−
w
i
s
e
s
h
i
f
t
row-wise\ shift
row−wise shift。对一个矩阵 A 进行
r
o
w
−
w
i
s
e
s
h
i
f
t
row-wise\ shift
row−wise shift操作,其结果为一个矩阵集合
F
(
A
)
:
F(A):
F(A):
F
(
A
)
=
{
A
+
Λ
J
N
,
M
∣
Λ
i
s
d
i
a
g
o
n
a
l
a
n
d
R
N
×
N
}
F(A)=\{ A+\Lambda J_{N,M}| \Lambda \ is\ diagonal\ and \ R^{{N}\times{N}}\}
F(A)={A+ΛJN,M∣Λ is diagonal and RN×N}
其中:
J
N
,
M
:
J_{N,M}:
JN,M:维度对应的全1矩阵
Λ
:
\Lambda:
Λ:对角线元素值任意的对角线矩阵
事实上,
r
o
w
−
w
i
s
e
s
h
i
f
t
row-wise\ shift
row−wise shift的作用是把矩阵
A
A
A中的每行元素上加上任意一个实数,例如如下
Λ
J
N
,
M
\Lambda J_{N,M}
ΛJN,M与矩阵
A
A
A相加后,
A
A
A的第 i 行会被加上一个实数
a
i
a_i
ai。
[
a
1
0
0
0
a
2
0
0
0
a
3
]
Λ
3
×
3
×
[
1
1
1
1
1
1
1
1
1
]
J
3
×
4
=
[
a
1
a
1
a
1
a
2
a
2
a
2
a
3
a
3
a
3
]
\left[ \begin{matrix} a_1&0&0 \\ 0&a_2&0 \\ 0&0&a_3 \end{matrix} \right]_{\Lambda^{3\times3} } \times{\left[ \begin{matrix} 1&1&1 \\ 1&1&1 \\ 1&1&1 \end{matrix} \right]_{J^{3\times4}}}={\left[ \begin{matrix} a_1&a_1&a_1 \\ a_2&a_2&a_2 \\ a_3&a_3&a_3 \end{matrix} \right]}
⎣⎡a1000a2000a3⎦⎤Λ3×3×⎣⎡111111111⎦⎤J3×4=⎣⎡a1a2a3a1a2a3a1a2a3⎦⎤
而代表真实下一符号概率分布的
log
\log
log 矩阵
A
A
A 与
A
A
A 经由
r
o
w
−
w
i
s
e
s
h
i
f
t
row-wise\ shift
row−wise shift 所得到的矩阵集合
F
(
A
)
F(A)
F(A) ,有如下两个特殊性质:
1.所有真实数据分布所对应的logits都包含在了集合
F
(
A
)
F(A)
F(A)中。
2.
F
(
A
)
F(A)
F(A) 中的所有矩阵的秩
都相似,相差不大于1。
附–矩阵的秩 :
-定义: 矩阵中所有线性独立的列的数目和
-直观解释:如果一个矩阵有着更高的秩,那么说明它有更多的线性独立的列。若把这些列看作是一组 basis vectors ,那么它们所能表达的空间就更复杂,表达能力就更强。即,高秩的矩阵能包含更多的信息量。
-例子:如果我们把某自然语言L表示成矩阵形式(如上节中的矩阵
A
A
A),那么此矩阵
A
A
A天然拥有高秩的性质,例如:
-它是高度依赖上下文的——“南”后面的符号可以是“京”或者“瓜”,取决于前后文是关于地理的还是农业的。即,在不同的上下文里,下一符号的概率分布会非常不同。
-并且我们不可能找到一组有限数目的basis vectors,使用此基来表达语言L中的所有Token的关系。
review
由RNN-based LM的结构推导出,它的Objective如下:
P
θ
(
X
∣
c
)
=
e
x
p
(
h
c
T
w
x
)
∑
x
e
x
p
(
h
c
T
w
x
)
=
P
∗
(
X
∣
c
)
P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
Pθ(X∣c)=∑xexp(hcTwx)exp(hcTwx)=P∗(X∣c)
通过把自然语言表达成矩阵形式,再进行矩阵分解(Matrix Factorization ),LM的目标可以抽象成如下表达。即,LM需要找到一组参数,借由这组参数生成的下一符号概率能无限逼近真实概率:
∃
θ
,
log
(
S
o
f
t
m
a
x
(
H
θ
W
θ
T
)
)
=
A
\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
∃θ,log(Softmax(HθWθT))=A
而通过引入矩阵运算符 row-wise shift ,以及此运算产生的矩阵集F(A)的第一个性质,我们可以推出,若RNN-based LM真的能逼近真实概率分布,那么它产生的 logits 必定属于真实概率分布矩阵 A 的 row-wise shift 结果集合中。即,Objective为如下:
∃
θ
,
s
u
c
h
t
h
a
t
,
H
θ
W
θ
T
∈
F
(
A
)
\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}
∃θ,such that,HθWθT∈F(A)
Problem: Softmax Bottleneck
至此,LM问题的核心变成了研究是否真的存在一组参数
θ
,
\theta,
θ,使基于此
θ
\theta
θ的LM所产生的logits属于
F
(
A
)
F(A)
F(A) ,如下:
∃
θ
,
s
u
c
h
t
h
a
t
,
H
θ
W
θ
T
∈
F
(
A
)
\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}
∃θ,such that,HθWθT∈F(A)
回忆一下,如上公式中:
H
θ
∈
R
N
×
d
,
H_{\theta}\in{R^{N\times{d}}},
Hθ∈RN×d,代表了所有可能的context输入下的对应隐藏状态。
W
θ
T
∈
R
M
×
d
,
W^T_{\theta}\in{R^{M\times{d}}},
WθT∈RM×d,代表了语言中所有可能的token所对应embedding coefficient。
因此,由线性代数的知识可知,它们乘积的秩应该小于d,即:
r
a
n
k
(
H
θ
W
θ
T
)
≤
d
rank(H_{\theta}W^T_{\theta})\leq{d}
rank(HθWθT)≤d
(相较于自然语言中的context数目N和token数目M,embedding size d显然会小很多)
又由于row-wise shift的第二个性质(即:
F
(
A
)
F(A)
F(A)中的所有矩阵的秩都相似,相差不大与1)可推导出,若embedding size d有:
d
<
m
i
n
A
′
∈
F
(
A
)
r
a
n
k
(
A
′
)
d<min_{A^{'}\in{F(A)}}rank(A^{'})
d<minA′∈F(A)rank(A′)
则对应的RNN-based LM 产生的logits不可能属于
F
(
A
)
F(A)
F(A)。换句话说,此LM不可能找到一组参数
θ
\theta
θ,使其能recover真实概率分布A。
到底embedding size d能否满足上述不等式呢?我们已知,真实概率分布矩阵A也属于F(A),而且它是高秩的矩阵,其秩最大能和它的context数目相当($ 10^{5}$)。而embedding本就是为了精简输入维度而使用的,所以它的维度一般会较小(
1
0
2
10^2
102)。所以显然成立:
d
<
m
i
n
A
′
∈
F
(
A
)
r
a
n
k
(
A
′
)
d<min_{A^{'}\in{F(A)}}rank(A^{'})
d<minA′∈F(A)rank(A′)
即,RNN-based LM 不可能找到一组参数
Θ
\Theta
Θ ,使其能recover真实概率分布 A。它只是一个真实概率分布的低秩近似,表达能力不够,因此失去了一些模拟context间依赖性的能力。这也正是性能瓶颈所在。
Sloution for Softmax Bottleneck
Naive Solution
要解决这个瓶颈问题,一个最直观的方法就是提高embedding size d。但是这显然与embedding的目的不符。另一个方法是使用Ngram模型,来避免Softmax的使用。这两种方法都会使总参数数目急剧增加,容易导致过拟合,显然都不可取。
Mixture of Softmaxes
而另一种方法就是使用作者提出的 MoS(Mixture of Softmaxes) 来替代原始的 Softmax 。MoS的公式如下:
P
θ
(
X
∣
c
)
=
∑
k
=
1
K
π
c
,
k
e
x
p
(
h
c
,
k
T
w
x
)
∑
x
e
x
p
(
h
c
,
k
T
w
x
)
s
.
t
.
∑
k
=
1
K
π
c
,
k
=
1
P_{\theta}(X|c) = \sum^K_{k=1}\pi_{c,k}\frac{exp(h^T_{c,k}w_x)}{\sum_xexp(h^T_{c,k}w_x)} \ \ \ \ \ \ \ s.t. \ \sum^K_{k=1}\pi_{c,k}=1
Pθ(X∣c)=k=1∑Kπc,k∑xexp(hc,kTwx)exp(hc,kTwx) s.t. k=1∑Kπc,k=1
由名字可知,Mos便是把多个Softmax按权相加,综合为一个Softmax混合模型。
传统的RNN-based LM的结构如下左图,而基于MoS的RMM-LM 位于下图右。由比较可看出,仅在RNN的hidden state
h
t
h_t
ht 以后有所不同。
这两种不同的模型最后产生的下一符号概率分布的
log
\log
log也不同,如下:
A
^
M
o
S
=
log
∑
k
=
1
K
Π
k
exp
(
H
θ
,
k
W
θ
T
)
\widehat{A}_{MoS}=\log\sum^K_{k=1}\Pi_k\exp(H_{\theta,k}W^T_\theta)
A
MoS=logk=1∑KΠkexp(Hθ,kWθT)
A
^
S
o
f
t
m
a
x
=
log
exp
(
H
θ
W
θ
T
)
\widehat{A}_{Softmax}=\log\exp(H_{\theta}W^T_\theta)
A
Softmax=logexp(HθWθT)
A
^
M
o
S
\widehat{A}_{MoS}
A
MoS 这个优化版本由于引入了按权相加,因此在最后计算完
log
\log
log运算后,与模型产生的logits不再是原本的线性关系,理论上可以达到任意的高秩,因此提升了模型的表达能力。
Experiments
使用MoS的RNN与其他模型在LM问题上的表现对比如下:
Drawback
当然,MoS模型也有它的缺憾。由于使用了多个并行的Softmax按权相加,因此它的运算时间是原有模型的数倍。在实践中,其实Softmax Layer的计算是尤其费时的,因此这也算是不小的短板。由下图实验数据可知,MoS模型的计算时间与它所用的Softmax的数目K近似呈线性关系。
Summary
现在普遍使用的RNN-based LM,由于在最后把RNN输出的隐藏状态
h
t
h_t
ht乘以了output embedding matrix,并把得到的结果(logits)输入了softmax layer,导致最后整体模型所能模拟的概率分布空间的秩被embedding-size d 所限制。而MoS模型通过引入按权相加的运算打破了原来的线性关系,提高了模型模拟空间的秩。当然,其代价是线性增加的运算时间。
##REFERENCES
[1]Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In ICLR 2018.
[2]Anton Maximilian Schäfer and Hans Georg Zimmermann. Recurrent neural networks are universal approximators. In International Conference on Artificial Neural Networks, pp. 632–640. Springer, 2006.
[3]Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Cernocky, and Sanjeev Khudanpur. Recurrent neural network based language model. In Interspeech, volume 2, pp. 3, 2010.
[4]Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and optimizing lstm language models. arXiv preprint arXiv:1708.02182, 2017.