选自arxiv
作者:杨植麟、Thang Luong等
机器之心编译
参与:魔王、杜伟
2017 年,杨植麟等人提出一种解决 Softmax 瓶颈的简单有效的方法——Mixture of Softmaxes(MoS)。但该方法成本高昂,于是最近杨植麟等人再次瞄准 softmax 瓶颈问题,提出兼顾表达能力和高效性的新方法 Mixtape。
论文链接:https://papers.nips.cc/paper/9723-mixtape-breaking-the-softmax-bottleneck-efficiently.pdf
softmax 瓶颈限制了神经语言模型的表达能力(expressiveness)。Mixture of Softmaxes (MoS) 是解决该理论局限的有效方法,但与 softmax 相比,MoS 无论在内存还是时间上都成本较高。
来自 CMU 和谷歌大脑的杨植麟、Thang Luong、Ruslan Salakhutdinov 和 Quoc Le 提出了一种新方法 Mixtape,该输出层利用三项新技术——logit 空间向量门控(logit space vector gating)、sigmoid 树分解(sigmoid tree decomposition)和门控共享(gate sharing),更高效地打破了 softmax 瓶颈。
研究者在语言建模和机器翻译等四个基准数据集上进行了实验,结果表明 Mixtape 层与 MoS 层性能相当,且 Mixtape 的效率是后者的 3.5-10.5 倍。在 10-30K 的词汇量下,使用 Mixtape 的网络仅比基于 softmax 的网络慢 20%-34%,其困惑度和翻译质量均优于 softmax。
softmax 带给我们的苦与乐
大量神经网络使用 softmax 作为标准输出层,包括大部分神经语言模型。但是,正如杨植麟等人在
之前研究 [19]
中所指出的,softmax 限制了神经语言模型的表达能力,因为它将输出表示限制在低秩,这不足以建模自然语言的复杂性。该局限叫做「softmax 瓶颈」。
为了打破这一瓶颈,[19] 提出新方法 Mixture of Softmaxes (MoS),它将离散潜在变量引入到输出层中,通过 log-sum-exp 非线性变换使对数概率矩阵变为高秩。但是,MoS 的内存和时间成本均高于 softmax,这使得它在计算资源有限的情况下实际应用性减弱。
MoS →Mixtape
为了降低 MoS 的计算成本,最近杨植麟等人提出了一种高效解决 softmax 瓶颈的新型输出层 Mixtape。Mixtape 可作为额外层嵌入到任意现有网络的交叉熵损失函数之前。它不像 MoS 在概率空间中部署标量混合(scalar mixture),而是在 logit 空间中应用向量门控机制,以避免使用多个成本高昂的 softmax。
此外,Mixtape 还使用了另外两项新技术来进一步降低计算成本。
首先,向量门控机制成本高昂,因为该机制需要我们对词汇表中的每一个词计算 softmax 门控。为此,研究者提出了 sigmoid 树分解技术,将 softmax 概率门控分布分解为一个深度为 2 的二叉树结构,每一个分支所包含的概率值部分由 sigmoid 函数决定。sigmoid 树分解更加高效,因为它避免了 softmax 中的减除运算。
另一项技术是门控共享,即对所有低频词共享门控值,得到部分高秩的表示。该技术在不影响性能的情况下,节约了一定量的内存和计算资源,因为即使没有门控共享,低频词的门控值通常也很难准确估计。
Mixtape 有多强大
Mixtape 结合了以上三项技术,效率较 MoS 有显著提升,而且 Mixtape 在四个基准数据集上的性能堪比甚至超过后者。在正常词汇量情况下(如 10K-30K),基于相同的批大小,Mixtape 层的速度是 MoS 层的 1.6-11.5 倍;基于相同的内存,Mixtape 层的速度是 MoS 的 3.5-10.5 倍。
在正常词汇量情况下,基于相同的批大小,使用 Mixtape 的网络速度仅比使用 softmax 的网络慢 5%-18%;基于相同的内存,前者仅比后者慢 20%-34%。在包含 100K token 的大词汇量情况下,基于 Mixtape 的网络的速度仅比基于 softmax 的网络慢 60%。
Mixtape 和 MoS 在困惑度和翻译质量方面均优于 softmax。有趣的是,这些基准数据集各自的词汇量从 10K 到 100K 不等,其输入表示也不尽相同(包括单词和 BPE 子词),这表明 Mixtape 对不同输入都很高效,具备稳健性。
Mixtape 高效解决 softmax 瓶颈的奥秘
softmax 瓶颈问题的定义。
2017 年,CMU 杨植麟等人提出了解决 Softmax 瓶颈的简单有效方法 MoS。近期,杨植麟等人为了解决效率低下问题,再次提出了新方法 Mixtape,该方法既能和 MoS 一样学习高秩表示,其效率又高于 MoS。
如前所述,Mixtape 运用了三项新技术,接下来我们就来查看其细节。
logit 空间向量门控
MoS 成本最高的部分是计算 K 个 softmax,如果我们能够只用一个 softmax 计算最终概率分布,就可以节约大量计算资源。我们很容易想到将混合(mixture)从概率空间移到 logit 空间,即在 softmax 运算之前混合表示,从而得到条件分布
。但是,正如论文 [19] 所述,该公式会导致低秩表示,因为它仍然使用公式 (1) 中的矩阵分解。
于是,该研究做了一个小小的修改:在 logit 空间中使用混合运算,从而得到高秩表示。其关键思想是:使用向量门控机制,而不是标量混合。也就是说,该方法未对每一个 token 使用共享混合权重集,而是对不同的 token 应用不同的权重集。使用向量门控后,条件分布的公式可写为:
但是,在 Mixtape 取得高效率的路上还有一个障碍。对于每一个 context-token 对 (c, x),先验 π_c,x,k 需要执行归一化操作,而这需要对每一对的先验概率执行 softmax 运算。
Sigmoid 树分解
为了高效计算先验 π_c,x,k,研究者不使用 softmax,而是提出一种新技术——将 softmax 分布分解为 sigmoid 函数树结构。具体而言,计算 (K − 1) 个 sigmoid 输出,并利用它们定义树分支的概率。例如,当 K = 4 时,先验被定义为:
其中 γ_∗ 表示 sigmoid 概率,σ 表示 sigmoid 函数。
这就是 sigmoid 树分解。这种分解能够通过 (K − 1) 个 sigmoid 函数完全恢复 K-way 概率分布。使用 sigmoid 函数可以移除 softmax 中的减除运算,更加高效。
令 g_c 作为语境 c 下的 d_1 维度最后一层隐藏状态,则预激活先验(pre-activation prior)l_∗ 的计算公式为:
其中 v_x ∈ R^(d_2)、U_k ∈ R^(d_2×d_1)、u_k ∈ R^(d_1)、b_x,k ∈ R 是模型参数。d_2 是表示门控嵌入大小的超参数,通常比正常的词嵌入大小 d 要小。语境嵌入可以通过以下公式得到:
其中 H_k ∈ R^(d×d_1) 是模型参数。
门控共享
通过上述两个方法已经可以得到高效的高秩模型,但仍然存在改进空间。研究者观察到,我们仍需要对词汇表中的每个 token 计算门控先验,而这成为影响效率的瓶颈。但是,由于缺乏训练样本,我们很难估计低频 token 的门控先验,因此学习低频 token 的门控先验可能只是对算力的浪费。基于此,研究者提出门控共享,即对所有低频词共享相同的门控先验。具体而言,对于低频 token x,预激活门控先验被定义为:
使用门控共享后,研究者可使用共享门控先验来混合语境嵌入 h_c,k,然后再与 token 嵌入 w_x 相乘,由于低频 token 无需存储门控 logits 从而节省了内存空间。门控共享还加速计算,因为所有低频 token 仅计算一组门控先验。
Mixtape 奥秘汇总
Mixtape 层可总结为:
- 给出最后一层的隐藏状态 g_c,使用公式 (5) 计算语境嵌入 h_c,k;
- 对每个高频 token x,使用公式 (4) 计算预激活门控先验 l_c,x,k;
- 对于所有低频 token,使用公式 (6) 计算预激活门控先验 l_c,x,k;
- 使用 sigmoid 树分解,计算公式 (3) 中的门控先验 π_c,x,k;
- 使用向量门控,利用公式 (2) 获得下一个 token 的概率。
Mixtape 层的架构如下图所示:
图 1:Mixtape 层架构图。
实验
实验包括三部分:
- Mixtape 层打破 softmax 瓶颈,从而改进了当前最优的机器翻译系统;
- 研究者对比了 Mixtape、MoS 和 softmax 的困惑度、翻译质量、速度和内存约束,证明 Mixtape 能够在效果和效率之间做好权衡;
- 控制变量实验证明了门控共享的优势。
表 1:在 WMT 英德和英法语言对数据上的性能对比。Mixtape 在这两项任务上分别使用了 2 亿和 8 亿参数。
表 2:数据集统计数据概况。「PTB」和「1B」分别表示 Penn Treebank 数据集和 One Billion Word 数据集。
表 3:模型在 Penn Treebank 上的困惑度和训练时间对比情况。
表 4:模型在 One Billion Word 数据集上的困惑度和训练时间对比情况。
表 5:模型在 WMT』14 英德语言对数据上的 BLEU 值和训练时间对比。
表 6:模型在 WMT』14 英法语言对数据上的 BLEU 值和训练时间对比。