发表于:neurips23
推荐指数: #paper/⭐⭐⭐
和node-wise 那个思路一样(或者说node-wise是这个的进一步延伸)
但是,他的公式有点让人不能最清晰的读懂(似懂非懂)
主干模型
即为每个节点(或者每个子图)选择相应的专家
模型是通过采样子图来构建的.
GMoE自适应选择1层或者2层的专家来动态的捕获端或长的邻居
h
i
′
=
σ
(
∑
o
=
1
m
∑
j
∈
N
i
G
(
h
i
)
o
E
o
(
h
j
,
e
i
j
,
W
)
+
∑
o
=
m
n
∑
j
∈
N
i
2
G
(
h
i
)
o
E
o
(
h
j
,
e
i
j
,
W
)
)
h_i^{\prime}=\sigma\left(\sum_{o=1}^m\sum_{j\in N_i}G(h_i)_oE_o\left(h_j,e_{ij},W\right)+\sum_{o=m}^n\sum_{j\in N_i^2}G(h_i)_oE_o\left(h_j,e_{ij},W\right)\right)
hi′=σ
o=1∑mj∈Ni∑G(hi)oEo(hj,eij,W)+o=m∑nj∈Ni2∑G(hi)oEo(hj,eij,W)
m是1跳专家数,m-n表示二跳专家数.
E
o
E_{o}
Eo表示第o个消息传递(即专家GNN).G是门控函数去生成多倍的决策分数,
G
(
h
i
)
o
G(h_{i})_{o}
G(hi)o表示第o层G的输出特征.具体的表示如下:
G
(
h
i
)
=
S
o
f
t
m
a
x
(
T
o
p
K
(
Q
(
h
i
)
,
k
)
)
,
Q
(
h
i
)
=
h
i
W
g
+
ϵ
⋅
S
o
f
t
p
l
u
s
(
h
i
W
n
)
,
G(h_i)=\mathrm{Softmax}(\mathrm{TopK}(Q(h_i),k)),\\Q(h_i)=h_iW_g+\epsilon\cdot\mathrm{Softplus}(h_iW_n),
G(hi)=Softmax(TopK(Q(hi),k)),Q(hi)=hiWg+ϵ⋅Softplus(hiWn),
k表示选择的专家数.
ϵ
∈
N
(
0
,
1
)
\epsilon\in\mathcal{N}(0,1)
ϵ∈N(0,1)表示标准高斯噪音.
h
i
h_{i}
hi的维度为
b
×
s
b\times s
b×s,其中b为batchsize,s为向量维度.
W
g
∈
R
s
×
n
W_g\in\mathbb{R}^{s\times n}
Wg∈Rs×n,
W
n
∈
R
s
×
n
W_n\in\mathbb{R}^{s\times n}
Wn∈Rs×n.(可以理解为:通过W,让
h
i
h_{i}
hi映射到
b
×
n
b\times n
b×n上,即可以评价n维的专家.)
模型的问题以及增补
但是,这会行程一个问题:首先被选中的专家会比其他专家的权重概率大.因此,为了让专家在初始时设置的更公平,我们可以用如下的设置去平衡:
Importance
(
H
)
=
∑
h
i
∈
H
,
g
∈
G
(
h
i
)
g
,
L
importance
(
H
)
=
C
V
(
Importance
(
H
)
)
2
\operatorname{Importance}(H)=\sum_{h_i\in H,g\in G(h_i)}g,\begin{array}{c}L_{\text{importance}}(H)=CV(\operatorname{Importance}(H))^2\end{array}
Importance(H)=hi∈H,g∈G(hi)∑g,Limportance(H)=CV(Importance(H))2
importance(H)被定义整个batch节点门控值g的和.CV代表了变量的系数.L测量重要性分数.
我们让
G
(
h
i
)
≠
0
G(h_{i})\neq_{}0
G(hi)=0当且仅当
Q
(
h
i
)
o
Q(h_{i})_{o}
Q(hi)o比
Q
(
h
i
)
Q(h_{i})
Q(hi)的第k大个元素大.(即前k个专家才有值,后面的全为0)
P
(
h
i
,
o
)
=
P
r
(
Q
(
h
i
)
o
>
k
t
h
e
x
(
Q
(
h
i
)
,
k
,
o
)
)
P(h_i,o)=Pr(Q(h_i)_o>\mathrm{kth_ex}(Q(h_i),k,o))
P(hi,o)=Pr(Q(hi)o>kthex(Q(hi),k,o))
其中,kth_ex()代表除了自己外第k-th大的元素.
P
(
h
i
,
o
)
P(h_{i},o)
P(hi,o)可以被简化为:
P
(
h
i
,
o
)
=
Φ
(
h
i
W
g
−
k
t
h
e
x
(
Q
(
h
i
)
,
k
,
o
)
Softplus
(
h
i
W
n
)
)
,
P(h_i,o)=\Phi\left(\frac{h_iW_g-\mathrm{kth_ex}(Q(h_i),k,o)}{\text{Softplus}(h_iW_n)}\right),
P(hi,o)=Φ(Softplus(hiWn)hiWg−kthex(Q(hi),k,o)),
ϕ
\phi
ϕ是标准正态分布的CDF.
L
l
o
a
d
(
H
)
=
C
V
(
∑
h
i
∈
H
,
p
∈
P
(
h
i
,
o
)
p
)
2
.
L_{\mathrm{load}}(H)=CV(\sum_{h_i\in H,p\in P(h_i,o)}p)^2.
Lload(H)=CV(hi∈H,p∈P(hi,o)∑p)2.
其中,p是batch的node-wise近似性.
L
=
L
E
M
+
λ
(
L
l
o
a
d
(
H
)
+
L
importance
(
H
)
)
L=L_{EM}+\lambda(L_{\mathrm{load}}(H)+L_{\text{importance}}(H))
L=LEM+λ(Lload(H)+Limportance(H))
L
E
M
L_{EM}
LEM表示MOE具体任务的期望最大化损失.
λ
\lambda
λ是平衡超参
预训练
L ( H , M ) = D ( d ( f ( M ⋅ H ) ) , H ) \mathcal{L}\left(H,M\right)=D\left(d\left(f\left(M\cdot H\right)\right),H\right) L(H,M)=D(d(f(M⋅H)),H)
D是距离矩阵, f f f是编码器,d是解码器.M是mask矩阵.利用自编码器来进行预训练.
GMOE结果:
利用GMOE预训练来加强模型: