AF3: Understanding the AlphaFold3Loss Class

The AlphaFold3Loss class defines AlphaFold3's total loss function. It aggregates several sub-loss terms and computes the final loss value in the forward pass.

Source code:

import logging

import torch
import torch.nn as nn

# distogram_loss, smooth_lddt_loss, and mse_loss are the individual loss
# functions defined elsewhere in the codebase.


class AlphaFold3Loss(nn.Module):
    """Aggregation of various losses described in the supplement."""

    def __init__(self, config):
        super(AlphaFold3Loss, self).__init__()
        self.config = config

    def _compute_losses(self, out, batch):
        losses = {}
        
        # Distogram loss
        losses["distogram"] = distogram_loss(
            logits=out["distogram_logits"],
            **{**batch, **self.config.distogram}
        )
        
        # Smooth LDDT loss
        losses["smooth_lddt_loss"] = smooth_lddt_loss(
            pred_atoms=out["denoised_atoms"],
            gt_atoms=out["augmented_gt_atoms"],
            atom_is_rna=torch.zeros_like(batch["atom_exists"]),
            atom_is_dna=torch.zeros_like(batch["atom_exists"]),
            mask=batch["atom_exists"],
        )
        
        # MSE loss
        losses["mse_loss"] = mse_loss(
            pred_atoms=out["denoised_atoms"],
            gt_atoms=out["augmented_gt_atoms"],
            timesteps=out["timesteps"],
            weights=batch["atom_exists"],
            mask=batch["atom_exists"],
            **self.config.mse_loss,
        )
        return losses

    def _aggregate_losses(self, losses):
        """Aggregate the losses with their respective weights."""
        cumulative_loss = 0.0
        for loss_name, loss in losses.items():
            weight = self.config[loss_name].weight
            if torch.isnan(loss):
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            elif torch.isinf(loss):
                logging.warning(f"{loss_name} loss is inf. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cumulative_loss = cumulative_loss + weight * loss
            losses[loss_name] = loss.detach().clone()
        
        losses["unscaled_loss"] = cumulative_loss.detach().clone()
        losses["loss"] = cumulative_loss.detach().clone()
        return cumulative_loss, losses

    def forward(self, out, batch, _return_breakdown=False):
        losses = self._compute_losses(out, batch)
        cumulative_loss, losses = self._aggregate_losses(losses)
        
        if _return_breakdown:
            return cumulative_loss, losses
        return cumulative_loss

Code walkthrough:

1. Purpose of the class
class AlphaFold3Loss(nn.Module):

This class inherits from torch.nn.Module. During AlphaFold3 training it computes the individual loss terms, then combines them via a weighted sum to produce the final training loss.
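The weighted aggregation, including the NaN/inf guard from `_aggregate_losses`, can be sketched as a standalone function. `aggregate_weighted_losses` and its argument names below are illustrative, not part of the original code:

```python
import logging
import math


def aggregate_weighted_losses(losses, weights):
    """Weighted sum of named scalar losses, skipping non-finite terms.

    Mirrors the pattern in AlphaFold3Loss._aggregate_losses (using plain
    floats for brevity): each term is scaled by its configured weight, and
    NaN/inf terms are replaced with zero so that one bad term cannot
    poison the whole training step.
    """
    total = 0.0
    for name, value in losses.items():
        if not math.isfinite(value):
            logging.warning("%s loss is not finite. Skipping...", name)
            value = 0.0
        total += weights[name] * value
    return total


# The NaN term "b" is zeroed; only 0.5 * 2.0 contributes.
print(aggregate_weighted_losses({"a": 2.0, "b": float("nan")},
                                {"a": 0.5, "b": 1.0}))  # → 1.0
```

In the real class the same zeroing happens via `loss.new_tensor(0., requires_grad=True)`, which keeps the result on the correct device and in the autograd graph.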

2. The _compute_losses method

_compute_losses builds a dictionary with three entries: the distogram loss on out["distogram_logits"], plus the smooth LDDT loss and the MSE loss, both comparing the denoised atoms against the augmented ground-truth atoms. Per-loss hyperparameters come from self.config, and batch["atom_exists"] serves as the atom mask (and as the MSE weights). For the smooth LDDT loss, atom_is_rna and atom_is_dna are passed as all-zero tensors, i.e. every atom is treated as non-nucleic here.
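The call `distogram_loss(logits=..., **{**batch, **self.config.distogram})` merges the batch features with the distogram config into a single keyword set; later dicts win on duplicate keys, and the loss function simply ignores keywords it does not consume. A minimal sketch with a toy stand-in for distogram_loss (`toy_distogram_loss` and its bin arguments are illustrative):

```python
def toy_distogram_loss(logits, min_bin=2.0, max_bin=22.0, **kwargs):
    """Toy stand-in: consumes only the keywords it needs; extras land in **kwargs."""
    return (max_bin - min_bin) + 0.0 * logits


batch = {"logits_scale": 1.0, "extra_feature": "ignored"}  # batch-style keys
config_distogram = {"min_bin": 2.0, "max_bin": 22.0}       # config-style keys

# Later dicts take precedence on duplicate keys, so config overrides batch.
merged = {**batch, **config_distogram}
print(toy_distogram_loss(logits=0.0, **merged))  # → 20.0
```

This keyword-merging idiom keeps the call site short at the cost of passing every batch feature to every loss, so each loss function needs a `**kwargs` catch-all.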