AlphaFold3: The AlphaFold3Loss Class
The AlphaFold3Loss class defines the total loss function of AlphaFold3. It aggregates several sub-loss terms and computes the final loss value during the forward pass.
Source code:
import logging

import torch
import torch.nn as nn

# distogram_loss, smooth_lddt_loss and mse_loss are loss functions defined
# elsewhere in the same codebase.

class AlphaFold3Loss(nn.Module):
    """Aggregation of various losses described in the supplement."""

    def __init__(self, config):
        super(AlphaFold3Loss, self).__init__()
        self.config = config

    def _compute_losses(self, out, batch):
        losses = {}

        # Distogram loss: the batch entries and the distogram config are
        # merged and forwarded as keyword arguments.
        losses["distogram"] = distogram_loss(
            logits=out["distogram_logits"],
            **{**batch, **self.config.distogram}
        )

        # Smooth LDDT loss: the RNA/DNA flags are zeroed here, so every atom
        # is treated as a non-nucleic-acid atom.
        losses["smooth_lddt_loss"] = smooth_lddt_loss(
            pred_atoms=out["denoised_atoms"],
            gt_atoms=out["augmented_gt_atoms"],
            atom_is_rna=torch.zeros_like(batch["atom_exists"]),
            atom_is_dna=torch.zeros_like(batch["atom_exists"]),
            mask=batch["atom_exists"],
        )

        # MSE loss on the denoised atom positions.
        losses["mse_loss"] = mse_loss(
            pred_atoms=out["denoised_atoms"],
            gt_atoms=out["augmented_gt_atoms"],
            timesteps=out["timesteps"],
            weights=batch["atom_exists"],
            mask=batch["atom_exists"],
            **self.config.mse_loss,
        )
        return losses

    def _aggregate_losses(self, losses):
        """Aggregate the losses with their respective weights."""
        cumulative_loss = 0.0
        for loss_name, loss in losses.items():
            weight = self.config[loss_name].weight
            # A NaN/inf term is replaced by a zero tensor that still
            # requires grad, so backpropagation is not broken.
            if torch.isnan(loss):
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            elif torch.isinf(loss):
                logging.warning(f"{loss_name} loss is inf. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cumulative_loss = cumulative_loss + weight * loss
            # Store a detached copy of each term for logging/reporting.
            losses[loss_name] = loss.detach().clone()
        losses["unscaled_loss"] = cumulative_loss.detach().clone()
        losses["loss"] = cumulative_loss.detach().clone()
        return cumulative_loss, losses

    def forward(self, out, batch, _return_breakdown=False):
        losses = self._compute_losses(out, batch)
        cumulative_loss, losses = self._aggregate_losses(losses)
        if _return_breakdown:
            return cumulative_loss, losses
        return cumulative_loss
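Before the walkthrough, here is a minimal sketch of what a compatible config could look like. This is an assumption, not the project's actual config: the class only requires that self.config support attribute access (self.config.distogram, self.config.mse_loss) and item access (self.config[loss_name].weight), which ml_collections.ConfigDict provides. All values, and any field other than weight, are illustrative placeholders.

import ml_collections

# Hypothetical config sketch: only the `weight` fields are read by
# _aggregate_losses; the distogram and mse_loss sub-dicts are additionally
# unpacked as keyword arguments into the corresponding loss functions.
loss_config = ml_collections.ConfigDict({
    "distogram": {
        "weight": 3e-2,   # relative weight of the distogram term (illustrative)
    },
    "smooth_lddt_loss": {
        "weight": 1.0,    # illustrative
    },
    "mse_loss": {
        "weight": 4.0,    # illustrative
    },
})

criterion = AlphaFold3Loss(loss_config)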
Code walkthrough:
1. Role of the class
class AlphaFold3Loss(nn.Module):
This class inherits from torch.nn.Module. It computes the various loss terms used during AlphaFold3 training and combines them as a weighted sum to obtain the final training loss.
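To make the aggregation rule concrete, the standalone sketch below (with illustrative names and weights, not values from the project) reproduces the core of _aggregate_losses: each term is scaled by its configured weight, and a NaN or inf term is swapped for a zero tensor that still requires grad, so the backward pass is unaffected by the diverged term.

import torch

# Illustrative per-loss weights and dummy loss values.
weights = {"distogram": 3e-2, "smooth_lddt_loss": 1.0, "mse_loss": 4.0}
losses = {
    "distogram": torch.tensor(0.8),
    "smooth_lddt_loss": torch.tensor(float("nan")),  # simulates a diverged term
    "mse_loss": torch.tensor(1.5),
}

cumulative = 0.0
for name, loss in losses.items():
    if torch.isnan(loss) or torch.isinf(loss):
        # Replace the bad term with a zero that keeps requires_grad=True,
        # exactly as AlphaFold3Loss does, so .backward() still works.
        loss = loss.new_tensor(0., requires_grad=True)
    cumulative = cumulative + weights[name] * loss

print(cumulative)  # tensor(6.0240, grad_fn=...): 3e-2 * 0.8 + 0.0 + 4.0 * 1.5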