概述
MMoE (Multi-gate Mixture-of-Experts) 是一种多任务学习模型,主要用于处理多个相关但不完全相同的任务;例如搜索/广告/信息流排名中的点击和转化,在传统的机器学习方法中,往往需要针对不同的任务构建不同的模型,同时还要为不同的模型构建匹配的数据流(pipeline);这种情况下每个模型都是较为独立的,这会丧失不同模型之间相互作用的机会,从而导致两个具有相关性的任务之间不能产生相互作用;而MMoE作为上述问题的解决方案能够同时针对多个任务目标进行优化。
如下图所示,在早期的设计中,一个Tower就代表一个task,多个Tower会共享一个Shared-Bottom,这样的设计结构清晰明了,但是并没有过多的考虑任务之间的相关性;后来的OMoE(one-gate MoE)对这方面进行了优化,使用多组专家网络(Expert)替代Shared-Bottom,同时设计门控网络(Gate)从Input中学习不同Expert的作用权重,并将加权后的Expert结果输入到不同的Tower中,OMoE更擅长使用单个门控网络来寻找在所有任务中利用专家输出的最佳方式;MMoE为每个Tower都分配了一个Gate,每个门都学习一种特定的方式来利用专家的产出来完成其各自的任务。
资源推荐
此处推荐一个开源库,该存储库涵盖了rank领域过去十年的 SOTA 模型包括MMoE:
- torch版本的实现 https://github.com/shenweichen/DeepCTR-Torch
- tensorflow版本的实现 https://github.com/shenweichen/DeepCTR
为了便于理解,接下来我们探索一下MMoE的核心结构,并结合核心结构给出最基本的代码实现。
核心结构:
- Expert Networks(专家网络)
- Gating Networks(门控网络)
- Task-specific Towers(任务专用塔)
基本实现:
import torch
import torch.nn as nn
class Expert(nn.Module):
def __init__(self, input_size, expert_size):
super(Expert, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_size, expert_size),
nn.ReLU(),
nn.Linear(expert_size, expert_size),
nn.ReLU()
)
def forward(self, x):
return self.net(x)
class GatingNetwork(nn.Module):
def __init__(self, input_size, num_experts):
super(GatingNetwork, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_size, num_experts),
nn.Softmax(dim=1)
)
def forward(self, x):
return self.net(x)
class TaskTower(nn.Module):
def __init__(self, expert_size, tower_size, output_size):
super(TaskTower, self).__init__()
self.net = nn.Sequential(
nn.Linear(expert_size, tower_size),
nn.ReLU(),
nn.Linear(tower_size, output_size)
)
def forward(self, x):
return self.net(x)
class MMoE(nn.Module):
def __init__(self, input_size, num_experts, expert_size, num_tasks, tower_size, output_size):
super(MMoE, self).__init__()
# 专家网络
self.experts = nn.ModuleList([
Expert(input_size, expert_size)
for _ in range(num_experts)
])
# 每个任务的门控网络
self.gates = nn.ModuleList([
GatingNetwork(input_size, num_experts)
for _ in range(num_tasks)
])
# 任务特定塔
self.towers = nn.ModuleList([
TaskTower(expert_size, tower_size, output_size)
for _ in range(num_tasks)
])
def forward(self, x):
# 专家输出
expert_outputs = [expert(x) for expert in self.experts]
expert_outputs = torch.stack(expert_outputs, dim=1) # [batch_size, num_experts, expert_size]
final_outputs = []
for task_id, (gate, tower) in enumerate(zip(self.gates, self.towers)):
# 获取门控权重
gate_weights = gate(x).unsqueeze(-1) # [batch_size, num_experts, 1]
# 加权组合专家输出
combined_experts = (expert_outputs * gate_weights).sum(dim=1) # [batch_size, expert_size]
# 通过任务塔
task_output = tower(combined_experts)
final_outputs.append(task_output)
return final_outputs
使用示例:
# 模型参数
input_size = 128
num_experts = 4
expert_size = 64
num_tasks = 2
tower_size = 32
output_size = 1
# 初始化模型
model = MMoE(
input_size=input_size,
num_experts=num_experts,
expert_size=expert_size,
num_tasks=num_tasks,
tower_size=tower_size,
output_size=output_size
)
# 模拟输入数据
batch_size = 32
x = torch.randn(batch_size, input_size)
# 前向传播
outputs = model(x)
task1_output, task2_output = outputs
训练代码:
class MMoETrainer:
def __init__(self, model, task_weights=[1.0, 1.0]):
self.model = model
self.task_weights = task_weights
self.optimizer = torch.optim.Adam(model.parameters())
self.criterion = nn.MSELoss()
def train_step(self, x, y1, y2):
self.optimizer.zero_grad()
# 前向传播
outputs = self.model(x)
task1_output, task2_output = outputs
# 计算损失
loss1 = self.criterion(task1_output, y1)
loss2 = self.criterion(task2_output, y2)
total_loss = self.task_weights[0] * loss1 + self.task_weights[1] * loss2
# 反向传播
total_loss.backward()
self.optimizer.step()
return {
'total_loss': total_loss.item(),
'task1_loss': loss1.item(),
'task2_loss': loss2.item()
}
# 训练循环
trainer = MMoETrainer(model)
for epoch in range(num_epochs):
for batch in dataloader:
x, y1, y2 = batch
losses = trainer.train_step(x, y1, y2)
print(f"Epoch {epoch}, Losses: {losses}")
高级特性:
class EnhancedMMoE(nn.Module):
def __init__(self, config):
super(EnhancedMMoE, self).__init__()
# 添加特征处理层
self.feature_processor = nn.Sequential(
nn.BatchNorm1d(config.input_size),
nn.Dropout(config.dropout_rate)
)
# 专家网络使用残差连接
self.experts = nn.ModuleList([
ResidualExpert(config)
for _ in range(config.num_experts)
])
# 注意力增强的门控网络
self.gates = nn.ModuleList([
AttentionGate(config)
for _ in range(config.num_tasks)
])
# 多层任务塔
self.towers = nn.ModuleList([
DeepTaskTower(config)
for _ in range(config.num_tasks)
])
def forward(self, x):
x = self.feature_processor(x)
# ... 其余逻辑类似基础版本
应用场景:
- 推荐系统:
- 点击预测
- 转化预测
- 停留时间预测
- 广告系统:
- CTR预测
- CVR预测
- ROI预测
- 用户行为分析:
- 用户兴趣预测
- 用户活跃度预测
- 用户生命周期预测
– 公众号持续更新:“北北文的自留地”