The mse_loss function in AlphaFold3
This function computes the diffusion-based MSE loss used to train AlphaFold3 for structure prediction.
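Its core ideas are:
- Compute the MSE (mean squared error) loss, which measures the gap between the predicted atom coordinates (pred_atoms) and the ground-truth atom coordinates (gt_atoms).
- Rigidly align the two structures first, so that a global rotation or translation between them does not contribute to the loss.
- Weight the loss by the timestep (timesteps). Here the "timestep" is the noise level (standard deviation) of the diffusion step, and the MSE is rescaled according to it.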
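To see what the timestep weighting does, note that (t² + σ_data²) / (t·σ_data)² = 1/t² + 1/σ_data². The weight therefore behaves like 1/t² at low noise and levels off at 1/σ_data² (1/256 for the default sd_data = 16.0) at high noise. The small sketch below, in plain Python, mirrors the scaling_factor line of the source code. diffusion_loss_weight is a hypothetical name used only for illustration.

```python
def diffusion_loss_weight(t: float, sd_data: float = 16.0, epsilon: float = 1e-5) -> float:
    """(t^2 + sd_data^2) / ((t * sd_data)^2 + epsilon), i.e. roughly 1/t^2 + 1/sd_data^2."""
    return (t ** 2 + sd_data ** 2) / ((t * sd_data) ** 2 + epsilon)


# Low noise levels get a large weight; high noise levels approach 1 / sd_data**2.
for t in (0.5, 4.0, 16.0, 160.0):
    print(f"t = {t:6.1f}  ->  weight = {diffusion_loss_weight(t):.6f}")
```

The epsilon term only matters as t approaches 0, where it keeps the denominator away from zero.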
Source code:
    import torch
    from torch import Tensor
    from typing import Optional
    # Vec3Array, weighted_rigid_align and mean_squared_error are helpers from the
    # AlphaFold3 codebase this function comes from (their imports are not shown here).


    def mse_loss(
        pred_atoms: Tensor,  # (bs * samples_per_trunk, n_atoms, 3)
        gt_atoms: Tensor,  # (bs * samples_per_trunk, n_atoms, 3)
        timesteps: Tensor,  # (bs * samples_per_trunk, 1)
        weights: Tensor,  # (bs, n_atoms)
        mask: Optional[Tensor] = None,  # (bs, n_atoms)
        sd_data: float = 16.0,  # Standard deviation of the data
        epsilon: Optional[float] = 1e-5,
        **kwargs
    ) -> Tensor:  # scalar: mean over the batch of the per-sample (bs,) losses
        """Diffusion loss that scales the MSE and LDDT losses by the noise level (timestep)."""
        # Convert to Vec3Array
        pred_atoms = Vec3Array.from_array(pred_atoms)
        gt_atoms = Vec3Array.from_array(gt_atoms)
        # Align the gt_atoms to pred_atoms (removes global rotation/translation)
        aligned_gt_atoms = weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask)
        # MSE loss, one value per sample
        mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask)
        # Scale by (t**2 + σ**2) / (t * σ)**2, where σ = sd_data
        scaling_factor = (timesteps ** 2 + sd_data ** 2) / ((timesteps * sd_data) ** 2 + epsilon)
        scaled_mse = scaling_factor.squeeze(-1) * mse  # (bs,)
        # Average over batch dimension
        return torch.mean(scaled_mse)  # scalar; returning scaled_mse would give per-sample (bs,) losses
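The sketch below reproduces the three steps in plain PyTorch, without the AlphaFold3 helpers. It is not the repository's implementation. weighted_rigid_align_sketch, weighted_mse_sketch and mse_loss_sketch are hypothetical names. The alignment is a standard weighted Kabsch fit. The MSE uses 1/3 times a weighted mean over atoms, which is one plausible normalization; the repository's mean_squared_error may normalize differently. The sketch also assumes that weights and mask already share the leading shape of pred_atoms, (bs * samples_per_trunk, n_atoms).

```python
import torch
from torch import Tensor
from typing import Optional


def weighted_rigid_align_sketch(
    x: Tensor,                      # (B, N, 3) coordinates to move (here: ground truth)
    x_target: Tensor,               # (B, N, 3) coordinates to align onto (here: prediction)
    weights: Tensor,                # (B, N) per-atom weights
    mask: Optional[Tensor] = None,  # (B, N) 1 = real atom, 0 = padding
) -> Tensor:
    """Weighted Kabsch fit: rotate and translate x to best match x_target."""
    w = weights if mask is None else weights * mask
    w = w.unsqueeze(-1)                                   # (B, N, 1)
    w_sum = w.sum(dim=1, keepdim=True).clamp_min(1e-8)   # (B, 1, 1)
    mu_x = (w * x).sum(dim=1, keepdim=True) / w_sum       # weighted centroids
    mu_t = (w * x_target).sum(dim=1, keepdim=True) / w_sum
    xc, tc = x - mu_x, x_target - mu_t

    # Weighted cross-covariance H = sum_n w_n * xc_n tc_n^T, shape (B, 3, 3)
    h = torch.einsum("bni,bnj->bij", w * xc, tc)
    u, _, vh = torch.linalg.svd(h)
    v = vh.transpose(-1, -2)
    # Flip the last axis if needed so that R is a proper rotation (det(R) = +1)
    d = torch.sign(torch.linalg.det(v @ u.transpose(-1, -2)))      # (B,)
    diag = torch.stack([torch.ones_like(d), torch.ones_like(d), d], dim=-1)
    r = v @ torch.diag_embed(diag) @ u.transpose(-1, -2)          # (B, 3, 3)

    aligned = xc @ r.transpose(-1, -2) + mu_t
    # Assumption of this sketch: the aligned target is treated as a constant
    return aligned.detach()


def weighted_mse_sketch(
    pred: Tensor, target: Tensor, weights: Tensor, mask: Optional[Tensor] = None
) -> Tensor:
    """Per-sample 1/3 * mean_n(w_n * |pred_n - target_n|^2), shape (B,)."""
    m = torch.ones_like(weights) if mask is None else mask
    sq_dist = ((pred - target) ** 2).sum(dim=-1)          # (B, N)
    return (weights * m * sq_dist).sum(dim=-1) / (3.0 * m.sum(dim=-1).clamp_min(1.0))


def mse_loss_sketch(
    pred_atoms: Tensor,             # (B, N, 3)
    gt_atoms: Tensor,               # (B, N, 3)
    timesteps: Tensor,              # (B, 1) noise level per sample
    weights: Tensor,                # (B, N)
    mask: Optional[Tensor] = None,  # (B, N)
    sd_data: float = 16.0,
    epsilon: float = 1e-5,
) -> Tensor:
    aligned_gt = weighted_rigid_align_sketch(gt_atoms, pred_atoms, weights, mask)
    mse = weighted_mse_sketch(pred_atoms, aligned_gt, weights, mask)                  # (B,)
    scale = (timesteps ** 2 + sd_data ** 2) / ((timesteps * sd_data) ** 2 + epsilon)  # (B, 1)
    return torch.mean(scale.squeeze(-1) * mse)                                        # scalar
```

gt_atoms is fitted onto pred_atoms before the distances are taken. A prediction that differs from the ground truth only by a rotation and a translation therefore gets a loss that is numerically zero, which is the invariance the rigid alignment is meant to provide. Detaching the aligned target is a choice made in this sketch, so that gradients flow only through pred_atoms.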
1. Function parameters
def mse_loss(
pred_atoms: Tensor, # (bs * samples_per_trunk, n_atoms, 3)
gt_atoms: Tensor, # (bs * samples_per_trunk, n_atoms, 3)
timesteps: Tensor, # (bs * samples_per_trunk, 1)
weights: Tensor, # (bs, n_atoms)
mask: Optional[Tensor] = None, # (bs, n_atoms)
sd_data: float = 16.0, # Standard deviation of the data
epsilon: Optional[float] = 1e-5,
**kwargs
) -> T