出错的版本(top-1 稀疏路由:每个样本只送入路由概率最高的专家,并通过掩码索引赋值把结果写回输出张量):
def compute_balance_loss(self, routing_probs, eps=1e-10):
    """Return the mean per-sample entropy of the routing distribution.

    For every row of ``routing_probs`` this evaluates
    ``-sum(p * log(p + eps))`` over the last dimension and then averages
    the result over all remaining dimensions.

    Args:
        routing_probs: Tensor of routing probabilities. The last dimension
            is the one that is reduced by the sum.
        eps: Small constant added inside the logarithm so that a
            probability of exactly zero does not yield ``log(0) = -inf``
            (and a NaN once multiplied by zero). Defaults to ``1e-10``,
            the value that used to be hard-coded.

    Returns:
        A scalar tensor containing the averaged entropy.
    """
    log_probs = torch.log(routing_probs + eps)
    per_sample_entropy = -torch.sum(routing_probs * log_probs, dim=-1)
    return per_sample_entropy.mean()
def forward(self, x, gate_inputs):
    """Route each sample to its top-1 expert and return the gated output.

    The gating network scores ``gate_inputs``; a softmax turns the scores
    into routing probabilities. Each sample is processed only by the expert
    with the highest probability, and that expert's output is scaled by the
    corresponding probability.

    Args:
        x: Input batch; indexed along dimension 0 with a boolean mask and
            fed to the selected expert.
        gate_inputs: Input to ``self.gating_network``. Its output is
            expected to be ``(batch_size, num_experts)`` per the original
            comments.

    Returns:
        A tuple ``(expert_outputs, balance_loss)`` where ``expert_outputs``
        has shape ``(batch_size, self.expert_hidden_dim)`` and
        ``balance_loss`` is whatever ``self.compute_balance_loss`` returns
        for the routing probabilities.
    """
    batch_size = x.size(0)

    # Routing probabilities over the experts, one row per sample.
    routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)

    # Index of the most probable expert for every sample.
    top1_expert_indices = routing_probs.argmax(dim=-1)

    # Rows are filled in expert by expert below.
    expert_outputs = torch.zeros(
        batch_size, self.expert_hidden_dim, device=x.device
    )

    for i in range(self.num_experts):
        expert_mask = top1_expert_indices == i
        if not expert_mask.any():
            continue

        # Indexing already yields a fresh tensor, so no extra clone is needed.
        gate = routing_probs[expert_mask, i].unsqueeze(-1)
        weighted_output = self.experts[i](x[expert_mask]) * gate

        # Masked assignment requires source and destination dtypes to match.
        # Promote the buffer when the experts emit a wider dtype (e.g.
        # float64) and cast narrower outputs (e.g. float16 under autocast)
        # up to the buffer's dtype instead of raising a RuntimeError.
        out_dtype = torch.promote_types(
            expert_outputs.dtype, weighted_output.dtype
        )
        if expert_outputs.dtype != out_dtype:
            expert_outputs = expert_outputs.to(out_dtype)
        expert_outputs[expert_mask] = weighted_output.to(out_dtype)

    balance_loss = self.compute_balance_loss(routing_probs)
    return expert_outputs, balance_loss
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
没有出错的版本(稠密加权:每个专家都处理整个批次,输出按路由概率加权求和):
def compute_balance_loss(self, routing_probs, eps=1e-10):
    """Return the mean per-sample entropy of the routing distribution.

    Computes ``-mean(sum(p * log(p + eps), dim=-1))``: the sum runs over
    the last dimension of ``routing_probs`` and the mean over everything
    that remains.

    Args:
        routing_probs: Tensor of routing probabilities whose last
            dimension is reduced by the sum.
        eps: Small constant added inside the logarithm to keep
            ``log`` finite when a probability is exactly zero. Defaults
            to ``1e-10``, the previously hard-coded value.

    Returns:
        A scalar tensor containing the averaged entropy.
    """
    weighted_log_probs = routing_probs * torch.log(routing_probs + eps)
    return -torch.mean(torch.sum(weighted_log_probs, dim=-1))
def forward(self, x, gate_inputs):
    """Combine the outputs of all experts, weighted by routing probability.

    Every expert processes the full batch ``x``; the results are summed
    after scaling each one by that expert's routing probability for the
    sample (a dense, soft mixture).

    Args:
        x: Input batch passed unchanged to every expert.
        gate_inputs: Input to ``self.gating_network``. Its output is
            expected to be ``(batch_size, num_experts)`` per the original
            comments.

    Returns:
        A tuple ``(expert_outputs, balance_loss)`` where ``expert_outputs``
        is the probability-weighted sum of the expert outputs and
        ``balance_loss`` is whatever ``self.compute_balance_loss`` returns
        for the routing probabilities.
    """
    batch_size = x.size(0)

    # Routing probabilities over the experts, one row per sample.
    routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)

    # Running weighted sum of the expert outputs.
    expert_outputs = torch.zeros(
        batch_size, self.expert_hidden_dim, device=x.device
    )

    for i, expert in enumerate(self.experts):
        expert_output = expert(x)
        routing_prob = routing_probs[:, i].unsqueeze(-1)
        # Accumulate out of place so ordinary type promotion applies: an
        # in-place ``+=`` would silently downcast float64 expert outputs
        # to the float32 buffer. float32/float16 results are unchanged.
        expert_outputs = expert_outputs + expert_output * routing_prob

    balance_loss = self.compute_balance_loss(routing_probs)
    return expert_outputs, balance_loss
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.