[Gradient update error] Haven't found the cause yet

The broken version:

    def compute_balance_loss(self, routing_probs):
        """
        Compute the balance loss (balance_loss).
        """
        # Per-sample entropy of the routing distribution, averaged over the batch.
        balance_loss = -torch.sum(routing_probs * torch.log(routing_probs + 1e-10), dim=-1).mean()
        return balance_loss

    def forward(self, x, gate_inputs):
        batch_size = x.size(0)
        
        # Compute routing probabilities
        routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)  # Shape (batch_size, num_experts)
        
        # Select the expert with the highest probability for each sample
        top1_expert_indices = routing_probs.argmax(dim=-1)  # Shape (batch_size,)
        
        # Initialize the output tensor
        expert_outputs = torch.zeros(batch_size, self.expert_hidden_dim, device=x.device)
        
        # Apply the selected expert
        for i in range(self.num_experts):
            expert_mask = top1_expert_indices == i
            if expert_mask.any():
                expert_input = x[expert_mask]
                expert_output = self.experts[i](expert_input)
                
                # Make a copy of the relevant part of routing_probs to avoid in-place operations on shared tensors
                routing_prob_copy = routing_probs[expert_mask, i].unsqueeze(-1).clone()
                expert_outputs[expert_mask] = expert_output * routing_prob_copy
        
        # Compute the balance loss separately using the function
        balance_loss = self.compute_balance_loss(routing_probs)
        
        return expert_outputs, balance_loss
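
Both the broken and the working snippet are methods of the same MoE module, whose constructor is not shown here. Below is a minimal scaffold sketch of that missing context, assuming a hypothetical `SimpleMoE` class with a linear gating network and small feed-forward experts; all class/parameter names and dimensions are illustrative, not from the original code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimpleMoE(nn.Module):
        """Hypothetical scaffold providing the attributes the snippets rely on."""
        def __init__(self, input_dim, gate_dim, expert_hidden_dim, num_experts):
            super().__init__()
            self.num_experts = num_experts
            self.expert_hidden_dim = expert_hidden_dim
            # Gating network: maps gate_inputs to one logit per expert.
            self.gating_network = nn.Linear(gate_dim, num_experts)
            # Experts: small feed-forward nets, each mapping x to expert_hidden_dim.
            self.experts = nn.ModuleList([
                nn.Sequential(nn.Linear(input_dim, expert_hidden_dim), nn.ReLU())
                for _ in range(num_experts)
            ])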

The working version:

    def compute_balance_loss(self, routing_probs):
        """
        Compute the balance loss (balance_loss).
        """
        # Per-sample entropy of the routing distribution, averaged over the batch
        # (mathematically the same as in the version above).
        balance_loss = -torch.mean(torch.sum(routing_probs * torch.log(routing_probs + 1e-10), dim=-1))
        return balance_loss

    def forward(self, x, gate_inputs):
        batch_size = x.size(0)
        
        # Compute routing probabilities
        routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)  # Shape (batch_size, num_experts)
        
        # Compute the final output as a probability-weighted sum over every expert
        expert_outputs = torch.zeros(batch_size, self.expert_hidden_dim, device=x.device)
        
        for i, expert in enumerate(self.experts):
            # Get the current expert's output
            expert_output = expert(x)  # Shape (batch_size, expert_hidden_dim)
            
            # Weight the output by this expert's routing probability
            routing_prob = routing_probs[:, i].unsqueeze(-1)  # Shape (batch_size, 1)
            expert_outputs += expert_output * routing_prob  # weighted accumulation
        
        # Compute the balance loss
        balance_loss = self.compute_balance_loss(routing_probs)
        
        return expert_outputs, balance_loss
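
Note that the two compute_balance_loss implementations are mathematically identical (both average the per-sample routing entropy over the batch), so the behavioral difference between the two versions is confined to forward: top-1 dispatch with masked in-place assignment versus a dense probability-weighted sum over all experts. A minimal training-step sketch, assuming the hypothetical SimpleMoE scaffold above with either forward attached, plus a made-up downstream head, dummy data, and an arbitrary balance-loss weight:

    model = SimpleMoE(input_dim=32, gate_dim=32, expert_hidden_dim=64, num_experts=4)
    head = nn.Linear(64, 1)  # hypothetical downstream head on top of the expert outputs
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(head.parameters()), lr=1e-3
    )

    x = torch.randn(8, 32)          # dummy input batch
    gate_inputs = x                 # here the gate sees the same features as the experts
    target = torch.randn(8, 1)      # dummy regression target

    expert_outputs, balance_loss = model(x, gate_inputs)
    task_loss = F.mse_loss(head(expert_outputs), target)
    loss = task_loss + 0.01 * balance_loss   # 0.01 is an arbitrary weight for the balance term
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()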