【Review】Breaking the Softmax Bottleneck:A High-rank RNN Language Model

Content

此篇论文主要完成了:
1.通过数学推导,找到并证明了限制RNN-based LMs的性能瓶颈之一——Softmax Bottleneck问题
2.针对这个瓶颈,提出了一个解决方案—— Mixture of Softmaxes

Introduction

Language Modeling Problem

LM问题在指已知了一个符号(token)序列:
X = ( X 1 , . . . , X T ) X=(X_1,...,X_T) X=(X1,...,XT)
的情况下,生成一个Language模型来模拟这个序列出现的概率,即求解P(X)的值:
P ( X ) = ∏ t P ( X t ∣ X < T ) P(X)=\prod_tP(X_t|X_{<T}) P(X)=tP(XtX<T)
而根据链式法则(chain rule)和马尔科夫假设(Markov Assumption), P(X)的值可以通过求解它对应的联合概率得出:
P ( X ) = ∏ t P ( X t ∣ X < T ) = ∏ P ( X t ∣ C t ) P(X)=\prod_tP(X_t|X_{<T})=\prod{P(X_t|C_t)} P(X)=tP(XtX<T)=P(XtCt)
X t X_t Xt: 下一个符号的概率分布
C t C_t Ct:历史序列 / 已经出现的所有Tokens

因此,原始的LM问题就转换成了:在每个时刻t,根据已知的符号序列 C t C_t Ct(History), 求解下一时刻可能输出的符号的概率。 由于输出符号有很多种而可能性,所以这个概率的其实是一个在Vocabulary(或Token Set)上的概率分布。

Standard Approach: RNN based LMs

由于符号序列自带时间属性,而我们需要模拟的也是符号之间的时间依赖关系。因此对于LM问题来说,最标准的,而且state of art 模型都是基于RNN的。以下为一个基于RNN的Language Model的结构简图:
RNN based LMs
其中: h t = σ ( V h t − 1 + U x t ) h_t=\sigma(Vh_{t-1}+Ux_t) ht=σ(Vht1+Uxt)
o = W h t o=Wh_t o=Wht
P ( x t ∣ c t ) = y t = e o ( x t ) ∑ v e o ( v ) P(x_t|c_t)=y_t=\frac{e^{o(x_t)}}{\sum_ve^{o(v)}} P(xtct)=yt=veo(v)eo(xt)
首先图片左下角的符号序列 " the cat sat on the " 是我们在此时刻 t 已知的历史序列,即 C t C_t Ct 。由于输入的每个Token是由one-hot编码表示的,当Vocabulary很大的情况下,这个输入维度会非常的高。因此在处理这种高维输入时,会先使用word embedding matrix ( 图中矩阵U ) 来降低维度&学习词语的内部联系,使输入更有意义。
之后,经过处理的输入会被传送给RNN。基于此时刻 t 的输入和上一时刻RNN的旧隐藏状态 $ h_{t-1} $ , RNN会产生新的隐藏状态 $ h_t $ 。
此隐藏状态可能看作是一个由RNN学习到的,含有下一时刻输出符号的信息的特征。由于输出是一个基于Vocabulary的概率分布,因此我们必须把学习到的这个特征映射回初始的Vocabulary。上图中, h t h_t ht下的output embedding matrix (W) 就负责这个反映射。应用上,两个embedding 矩阵 U和W是一样的。
###Hypothesis&Main Issues
由Anton Maximilian Schäfer and Hans Georg Zimmermann写的Recurrent Neural Networks Are Universal Approximators论文可以得知,RNN的表达力是很强的,它可以模拟逼近任意的非线性动态系统(Universal approximation theorem)。由此作者推测出,基于RNN的LMs的性能瓶颈之一应该是RNN最后使用点乘+softmax操作,即:$o=Wh_t ; out = softmax(o) $。

Mathematical Analysis of LM

Defination

为了能进行数学推导和定量分析来证明这个假设,首先我们需要一个自然语言的数学表达。自然语言L可以表示成N个元组的集合:
L = { ( c 1 , P ∗ ( X ∣ c 1 ) ) , . . . , ( c N , P ∗ ( X ∣ c N ) ) } L=\{(c_1,P^*(X|c_1)),...,(c_N,P^*(X|c_N))\} L={(c1,P(Xc1)),...,(cN,P(XcN))}
其中:
c i : c_i: ci:代表了语言中的任一个可能的context(history token序列)
P ∗ ( X ∣ c i ) : P^*(X|c_i): P(Xci):真实的数据分布,即:已知一个历史符号序列( c i c_i ci),下一符号在Token集合 X X X上的概率分布
X = { x 1 , x 2 , . . . , x M } : X=\{x_1,x_2,...,x_M\}: X={x1,x2,...,xM}: 代表了语言L中所有可能出现的符号
N : N: N: 所有可能的上下文(符号组合)的数目
至此,LM问题可以转换成如下的数学公式表达:
P θ ( X ∣ c ) = P ∗ ( X ∣ c ) P_\theta(X|c)=P^*(X|c) Pθ(Xc)=P(Xc)
即,给定一个自然语言L,LM需要学习一组参数 θ \theta θ,基于此组参数的模型可以逼近真实的任一上下文(context)所对应的下一符号概率分布。
若我们使用RNN-based LMs, 那么在network的输出端,我们能从softmax layer 的输出直接得到基于此时刻 t 的下一符号概率分布 P θ ( X ∣ c ) P_{\theta}(X|c) Pθ(Xc) :
P θ ( X ∣ c ) = e x p ( h c T w x ) ∑ x e x p ( h c T w x ) P_\theta(X|c)=\frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)} Pθ(Xc)=xexp(hcTwx)exp(hcTwx)
因此,训练模型的Objective可以用以下等式表达:
P θ ( X ∣ c ) = e x p ( h c T w x ) ∑ x e x p ( h c T w x ) = P ∗ ( X ∣ c ) P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c) Pθ(Xc)=xexp(hcTwx)exp(hcTwx)=P(Xc)
即,我们使用一个RNN-based LM 来模拟每个可能context下的下一符号概率分布,并且不断优化模型使用的参数 θ \theta θ,使LM输出的概率分布逼近真实分布。

Matrix Factorization Problem

在数学化表达LM问题后,它的Objective公式还可以通过矩阵分解来做进一步的分析。
P θ ( X ∣ c ) P_{\theta}(X|c) Pθ(Xc)的表达式中, h c T h_c^T hcT代表了输入是不同的context(历史序列)的情况下,RNN所对应的不同隐藏状态。此处,可以把所有可能的情况列出,排列组合成一个矩阵:
H θ = [ h c 1 T h c 2 T . . . h c N T ] H_{\theta}=\left[ \begin{matrix} h^T_{c_1} \\ h^T_{c_2} \\ ... \\ h^T_{c_N} \end{matrix} \right] Hθ=hc1Thc2T...hcNT
这个矩阵包含了RNN针对不同的Context的所有可能的隐藏状态。根据此节开篇的假设,自然语言L一共有 N N N种可能的context(即:符号组合序列)。
相似地,公式中的 w x w_x wx也可以统一成矩阵表达:
W θ = [ w x 1 T w x 2 T . . . w x M T ] W_{\theta}= \left[ \begin{matrix} w^T_{x_1} \\ w^T_{x_2} \\ ... \\ w^T_{x_M} \end{matrix} \right] Wθ=wx1Twx2T...wxMT
W θ W_{\theta} Wθ中的每一行代表了语言L中的某一个符号 x i x_i xi所对应的embedding coefficient,用以把RNN学到的隐藏状态映射回包含 X X X符号集的Vocabulary空间。同样,根据此节开篇的假设,自然语言L一共有 M M M种可能的符号(tokens)。
最后,我们还需要把自然语言L真实的条件概率分布(在各种可能的context下,下一符号的概率分布)用矩阵的方式表达,从而能使用矩阵知识,数学地分析RNN-based LMs。此处假设矩阵 A A A代表了真实条件概率分布 P ∗ ( X ∣ c ) P^*(X|c) P(Xc) log ⁡ \log log后的结果:
A = [ log ⁡ P ∗ ( x 1 ∣ c 1 ) log ⁡ P ∗ ( x 2 ∣ c 1 ) . . . log ⁡ P ∗ ( x M ∣ c 1 ) log ⁡ P ∗ ( x 1 ∣ c 2 ) log ⁡ P ∗ ( x 2 ∣ c 2 ) . . . log ⁡ P ∗ ( x M ∣ c 2 ) . . . . . . . . . . . . log ⁡ P ∗ ( x 1 ∣ c N ) log ⁡ P ∗ ( x 2 ∣ c N ) . . . log ⁡ P ∗ ( x M ∣ c N ) ] A= \left[ \begin{matrix} \log{P^*(x_1|c_1)} &\log{P^*(x_2|c_1)}&...&\log{P^*(x_M|c_1)}\\ \log{P^*(x_1|c_2)} &\log{P^*(x_2|c_2)}&...&\log{P^*(x_M|c_2)} \\ ...&...&...&... \\ \log{P^*(x_1|c_N)} &\log{P^*(x_2|c_N)}&...&\log{P^*(x_M|c_N)} \end{matrix} \right] A=logP(x1c1)logP(x1c2)...logP(x1cN)logP(x2c1)logP(x2c2)...logP(x2cN)............logP(xMc1)logP(xMc2)...logP(xMcN)
由上公式可知, A A A包含了context与对应next token的所有可能的组合。

Rank Analysis

在经历了上述对Objective的分析及矩阵转换,RNN-based LM问题事实上可以抽象如下:
∃ θ , log ⁡ ( S o f t m a x ( H θ W θ T ) ) = A \exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A θ,log(Softmax(HθWθT))=A
即,通过学习,我们希望找到一组参数 θ \theta θ,以它为参数的LM模型(即RNN)可以逼近真实的下一符号概率分布的 log ⁡ \log log
为了能推导出Softmax存在的瓶颈,首先先要引入一个矩阵操作 r o w − w i s e   s h i f t row-wise\ shift rowwise shift。对一个矩阵 A 进行 r o w − w i s e   s h i f t row-wise\ shift rowwise shift操作,其结果为一个矩阵集合 F ( A ) : F(A): F(A):
F ( A ) = { A + Λ J N , M ∣ Λ   i s   d i a g o n a l   a n d   R N × N } F(A)=\{ A+\Lambda J_{N,M}| \Lambda \ is\ diagonal\ and \ R^{{N}\times{N}}\} F(A)={A+ΛJN,MΛ is diagonal and RN×N}
其中:
J N , M : J_{N,M}: JN,M维度对应的全1矩阵
Λ : \Lambda: Λ对角线元素值任意的对角线矩阵
事实上, r o w − w i s e   s h i f t row-wise\ shift rowwise shift的作用是把矩阵 A A A中的每行元素上加上任意一个实数,例如如下 Λ J N , M \Lambda J_{N,M} ΛJN,M与矩阵 A A A相加后, A A A的第 i 行会被加上一个实数 a i a_i ai
[ a 1 0 0 0 a 2 0 0 0 a 3 ] Λ 3 × 3 × [ 1 1 1 1 1 1 1 1 1 ] J 3 × 4 = [ a 1 a 1 a 1 a 2 a 2 a 2 a 3 a 3 a 3 ] \left[ \begin{matrix} a_1&0&0 \\ 0&a_2&0 \\ 0&0&a_3 \end{matrix} \right]_{\Lambda^{3\times3} } \times{\left[ \begin{matrix} 1&1&1 \\ 1&1&1 \\ 1&1&1 \end{matrix} \right]_{J^{3\times4}}}={\left[ \begin{matrix} a_1&a_1&a_1 \\ a_2&a_2&a_2 \\ a_3&a_3&a_3 \end{matrix} \right]} a1000a2000a3Λ3×3×111111111J3×4=a1a2a3a1a2a3a1a2a3
而代表真实下一符号概率分布的 log ⁡ \log log 矩阵 A A A A A A 经由 r o w − w i s e   s h i f t row-wise\ shift rowwise shift 所得到的矩阵集合 F ( A ) F(A) F(A) ,有如下两个特殊性质:
1.所有真实数据分布所对应的logits都包含在了集合 F ( A ) F(A) F(A)中。
2. F ( A ) F(A) F(A) 中的所有矩阵的秩
都相似,相差不大于1。
附–矩阵的秩 :
-定义: 矩阵中所有线性独立的列的数目和
-直观解释:如果一个矩阵有着更高的秩,那么说明它有更多的线性独立的列。若把这些列看作是一组 basis vectors ,那么它们所能表达的空间就更复杂,表达能力就更强。即,高秩的矩阵能包含更多的信息量。
-例子:如果我们把某自然语言L表示成矩阵形式(如上节中的矩阵 A A A),那么此矩阵 A A A天然拥有高秩的性质,例如:
-它是高度依赖上下文的——“南”后面的符号可以是“京”或者“瓜”,取决于前后文是关于地理的还是农业的。即,在不同的上下文里,下一符号的概率分布会非常不同。
-并且我们不可能找到一组有限数目的basis vectors,使用此基来表达语言L中的所有Token的关系。

review

由RNN-based LM的结构推导出,它的Objective如下:
P θ ( X ∣ c ) = e x p ( h c T w x ) ∑ x e x p ( h c T w x ) = P ∗ ( X ∣ c ) P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c) Pθ(Xc)=xexp(hcTwx)exp(hcTwx)=P(Xc)
通过把自然语言表达成矩阵形式,再进行矩阵分解(Matrix Factorization ),LM的目标可以抽象成如下表达。即,LM需要找到一组参数,借由这组参数生成的下一符号概率能无限逼近真实概率:
∃ θ , log ⁡ ( S o f t m a x ( H θ W θ T ) ) = A \exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A θ,log(Softmax(HθWθT))=A
而通过引入矩阵运算符 row-wise shift ,以及此运算产生的矩阵集F(A)的第一个性质,我们可以推出,若RNN-based LM真的能逼近真实概率分布,那么它产生的 logits 必定属于真实概率分布矩阵 Arow-wise shift 结果集合中。即,Objective为如下:
∃ θ , s u c h   t h a t , H θ W θ T ∈ F ( A ) \exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)} θ,such that,HθWθTF(A)

Problem: Softmax Bottleneck

至此,LM问题的核心变成了研究是否真的存在一组参数 θ , \theta, θ,使基于此 θ \theta θ的LM所产生的logits属于 F ( A ) F(A) F(A) ,如下:
∃ θ , s u c h   t h a t , H θ W θ T ∈ F ( A ) \exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)} θ,such that,HθWθTF(A)
回忆一下,如上公式中:
H θ ∈ R N × d , H_{\theta}\in{R^{N\times{d}}}, HθRN×d,代表了所有可能的context输入下的对应隐藏状态。
W θ T ∈ R M × d , W^T_{\theta}\in{R^{M\times{d}}}, WθTRM×d,代表了语言中所有可能的token所对应embedding coefficient
因此,由线性代数的知识可知,它们乘积的秩应该小于d,即:
r a n k ( H θ W θ T ) ≤ d rank(H_{\theta}W^T_{\theta})\leq{d} rank(HθWθT)d
(相较于自然语言中的context数目N和token数目M,embedding size d显然会小很多)
又由于row-wise shift的第二个性质(即: F ( A ) F(A) F(A)中的所有矩阵的秩都相似,相差不大与1)可推导出,若embedding size d有:
d < m i n A ′ ∈ F ( A ) r a n k ( A ′ ) d<min_{A^{'}\in{F(A)}}rank(A^{'}) d<minAF(A)rank(A)
则对应的RNN-based LM 产生的logits不可能属于 F ( A ) F(A) F(A)。换句话说,此LM不可能找到一组参数 θ \theta θ,使其能recover真实概率分布A
到底embedding size d能否满足上述不等式呢?我们已知,真实概率分布矩阵A也属于F(A),而且它是高秩的矩阵,其秩最大能和它的context数目相当($ 10^{5}$)。而embedding本就是为了精简输入维度而使用的,所以它的维度一般会较小( 1 0 2 10^2 102)。所以显然成立:
d < m i n A ′ ∈ F ( A ) r a n k ( A ′ ) d<min_{A^{'}\in{F(A)}}rank(A^{'}) d<minAF(A)rank(A)
即,RNN-based LM 不可能找到一组参数 Θ \Theta Θ ,使其能recover真实概率分布 A。它只是一个真实概率分布的低秩近似,表达能力不够,因此失去了一些模拟context间依赖性的能力。这也正是性能瓶颈所在。

Sloution for Softmax Bottleneck

Naive Solution

要解决这个瓶颈问题,一个最直观的方法就是提高embedding size d。但是这显然与embedding的目的不符。另一个方法是使用Ngram模型,来避免Softmax的使用。这两种方法都会使总参数数目急剧增加,容易导致过拟合,显然都不可取。

Mixture of Softmaxes

而另一种方法就是使用作者提出的 MoS(Mixture of Softmaxes) 来替代原始的 Softmax 。MoS的公式如下:
P θ ( X ∣ c ) = ∑ k = 1 K π c , k e x p ( h c , k T w x ) ∑ x e x p ( h c , k T w x )         s . t .   ∑ k = 1 K π c , k = 1 P_{\theta}(X|c) = \sum^K_{k=1}\pi_{c,k}\frac{exp(h^T_{c,k}w_x)}{\sum_xexp(h^T_{c,k}w_x)} \ \ \ \ \ \ \ s.t. \ \sum^K_{k=1}\pi_{c,k}=1 Pθ(Xc)=k=1Kπc,kxexp(hc,kTwx)exp(hc,kTwx)       s.t. k=1Kπc,k=1
由名字可知,Mos便是把多个Softmax按权相加,综合为一个Softmax混合模型。
传统的RNN-based LM的结构如下左图,而基于MoSRMM-LM 位于下图右。由比较可看出,仅在RNNhidden state h t h_t ht 以后有所不同。
standard RNN vs. MoS
这两种不同的模型最后产生的下一符号概率分布的 log ⁡ \log log也不同,如下:
A ^ M o S = log ⁡ ∑ k = 1 K Π k exp ⁡ ( H θ , k W θ T ) \widehat{A}_{MoS}=\log\sum^K_{k=1}\Pi_k\exp(H_{\theta,k}W^T_\theta) A MoS=logk=1KΠkexp(Hθ,kWθT)
A ^ S o f t m a x = log ⁡ exp ⁡ ( H θ W θ T ) \widehat{A}_{Softmax}=\log\exp(H_{\theta}W^T_\theta) A Softmax=logexp(HθWθT)
A ^ M o S \widehat{A}_{MoS} A MoS 这个优化版本由于引入了按权相加,因此在最后计算完 log ⁡ \log log运算后,与模型产生的logits不再是原本的线性关系,理论上可以达到任意的高秩,因此提升了模型的表达能力。

Experiments

使用MoS的RNN与其他模型在LM问题上的表现对比如下:
result

Drawback

当然,MoS模型也有它的缺憾。由于使用了多个并行的Softmax按权相加,因此它的运算时间是原有模型的数倍。在实践中,其实Softmax Layer的计算是尤其费时的,因此这也算是不小的短板。由下图实验数据可知,MoS模型的计算时间与它所用的Softmax的数目K近似呈线性关系。
drawback

在这里插入图片描述

Summary

现在普遍使用的RNN-based LM,由于在最后把RNN输出的隐藏状态 h t h_t ht乘以了output embedding matrix,并把得到的结果(logits)输入了softmax layer,导致最后整体模型所能模拟的概率分布空间的秩被embedding-size d 所限制。而MoS模型通过引入按权相加的运算打破了原来的线性关系,提高了模型模拟空间的秩。当然,其代价是线性增加的运算时间。
##REFERENCES
[1]Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In ICLR 2018.
[2]Anton Maximilian Schäfer and Hans Georg Zimmermann. Recurrent neural networks are universal approximators. In International Conference on Artificial Neural Networks, pp. 632–640. Springer, 2006.
[3]Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Cernocky, and Sanjeev Khudanpur. Recurrent neural network based language model. In Interspeech, volume 2, pp. 3, 2010.
[4]Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and optimizing lstm language models. arXiv preprint arXiv:1708.02182, 2017.

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值