class MMoE_Layer(tf.keras.layers.Layer):
    """Multi-gate Mixture-of-Experts (MMoE) layer.

    Every expert is a ReLU ``Dense`` layer applied to the same input. Each
    task owns a softmax ``Dense`` gate that produces one weight per expert,
    and the task's output is the gate-weighted sum of all expert outputs.

    Args:
        expert_dim: Number of output units of each expert.
        n_expert: Number of expert networks shared by all tasks.
        n_task: Number of tasks, i.e. number of gates and of returned tensors.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, expert_dim, n_expert, n_task, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can reproduce them.
        self.expert_dim = expert_dim
        self.n_expert = n_expert
        self.n_task = n_task
        self.expert_layer = [
            Dense(expert_dim, activation='relu') for _ in range(n_expert)
        ]
        self.gate_layers = [
            Dense(n_expert, activation='softmax') for _ in range(n_task)
        ]

    def call(self, x):
        """Return one gate-weighted expert mixture per task.

        Args:
            x: Input tensor, assumed 2-D ``(batch, features)`` - the shape
                comments below rely on that; TODO confirm against callers.

        Returns:
            A list of ``n_task`` tensors, each of shape ``(batch, expert_dim)``.
        """
        # Run every expert and stack the results: (batch, n_expert, expert_dim).
        # tf.stack is used instead of building a Concatenate layer on each call.
        expert_out = tf.stack(
            [expert(x) for expert in self.expert_layer], axis=1
        )

        towers = []
        for gate in self.gate_layers:
            # Per-task mixing weights: (batch, n_expert) -> (batch, n_expert, 1).
            weights = tf.expand_dims(gate(x), axis=-1)
            # (batch, expert_dim, n_expert) @ (batch, n_expert, 1)
            #   -> (batch, expert_dim, 1): weighted sum over the experts.
            mixed = tf.matmul(expert_out, weights, transpose_a=True)
            # Drop the trailing singleton axis -> (batch, expert_dim).
            # tf.squeeze avoids creating a new Flatten layer per task per call.
            towers.append(tf.squeeze(mixed, axis=-1))
        return towers

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'expert_dim': self.expert_dim,
            'n_expert': self.n_expert,
            'n_task': self.n_task,
        })
        return config
# Core code block (caption for the MMoE_Layer snippet above).