这篇文章是在Outrageously Large Neural Networks: The Sparsely-Gated Mixtured-of-Experts Layer 基础上对 MoE 的优化,上篇文章具体内容见链接帖子。
1. Introduction
文章提出,在之前的工作中,通过门控让token选择top-k个专家并输入,即每个token经过的专家数量是固定的。这种独立tokne选择往往会导致expert负载不平衡导致模型的训练效率低下和次优训练。之前的方法通过引入稀疏门控网络及其对应的loss正则项,用于防止太多的token被router分配给单个专家,但效果仍然有限。
本文则提出让专家反过来选择top-k个token,解决不同专家得到的token数量不同导致部分专家训练不充分的问题,以及不同token经过的专家数量相同忽略了token对于任务重要性不同的问题。
具体来说,MoE的每个专家都是一个子网络,每个输入的token通过router分配到不同专家网络。router的分配方式可以选择k-means聚类、基于token-expert相关性的线性分配、哈希等。
图左是之前工作的方法,图右是本文方法
2. Token-Choice Routing 的缺陷
同上,可以归结为三个点:
- 负载不平衡:由于token分配不均,部分专家被分配的token过少导致训练、利用得不够充分;部分专家被分配的token过多,但由于内存限制只能选择一定数量的token使用,导致token资源的浪费。
- 冗余专业化/专业化不足:理想情况下,学习的门控网络应该生成addinate打分,将相似或相关的token分配到同一个专家。次优策略可能会产生冗余的专家或不够专业的专家。
- 每个token经过的计算量相同,忽略token重要性的不同。
3. Expert-Choice 的方法
就像前面所说,方法很简单,即将每个token选择top-k个expert,变为每个expert选择top-k个token。
具体实验中,取
k
=
n
⋅
c
e
k=\frac{n\cdot c}{e}
k=en⋅c
其中, n n n 是一个batch的token数, c c c 是超参数表示一个token平均使用几个expert, e e e 是专家数。
- 输入一个batch的 X ∈ R n × d X\in R^{n\times d} X∈Rn×d
- 设置索引矩阵 I ∈ R e × k I\in R^{e\times k} I∈Re×k, I [ i , j ] I[i,j] I[i,j]表示第i个expert的第j个选择token
- 设置与 I I I对应的矩阵 P ∈ R e × k × n P\in R^{e\times k\times n} P∈Re×k×n,作为 I I I的one-hot版本
- 门控矩阵 G ∈ R e × k G\in R^{e\times k} G∈Re×k用于选择权重
router:
S
=
S
o
f
t
m
a
x
(
X
⋅
W
g
)
S=Softmax(X\cdot W_g)
S=Softmax(X⋅Wg)
其中,
S
S
S是token和expert匹配的打分,
W
g
W_g
Wg是expert embedding作为学习参数。
G
,
I
=
T
o
p
k
(
S
T
,
k
)
P
=
O
n
e
h
o
t
(
I
)
G, I=Topk(S^T,k) \\ P=Onehot(I)
G,I=Topk(ST,k)P=Onehot(I)
为 S T S^T ST的每一行选择k个最大的得分, I I I是选择结果的索引, G G G是结果对应的权重。
expert:
X
i
n
=
P
⋅
X
X_{in}=P\cdot X
Xin=P⋅X
其中,
X
i
n
X_{in}
Xin是作为专家的FFN的输入,
X
i
n
[
i
]
∈
R
k
×
d
X_{in}[i]\in R^{k\times d}
Xin[i]∈Rk×d表示第i个专家的输入。
X
e
[
i
]
=
G
e
L
U
(
X
i
n
[
i
]
⋅
W
1
[
i
]
)
⋅
W
2
[
i
]
X
o
u
t
[
l
,
d
]
=
∑
i
,
j
P
[
i
,
j
,
l
]
G
[
i
,
j
]
X
e
[
i
,
j
,
d
]
X_e[i]=GeLU(X_{in}[i]\cdot W_1[i])\cdot W_2[i] \\ \ \\ X_{out}[l,d]=\sum_{i,j}P[i,j,l]G[i,j]X_e[i,j,d]
Xe[i]=GeLU(Xin[i]⋅W1[i])⋅W2[i] Xout[l,d]=i,j∑P[i,j,l]G[i,j]Xe[i,j,d]
FFN的输出即为 X o u t X_{out} Xout。
方法并不困难,思路简单。