AF3: AlphaFoldWrapper Class Walkthrough

The AlphaFoldWrapper class in AlphaFold3 inherits from LightningModule. It is the module that manages training, validation, and optimization for AlphaFold3; its main responsibilities are:

  • Managing the model's training/validation loop
  • Computing losses and logging them
  • Integrating an exponential moving average (EMA) of the weights
  • Configuring the optimizer and learning-rate schedule
  • Managing parameters during training/validation

Source code:

class AlphaFoldWrapper(LightningModule):
    def __init__(self, config):
        super(AlphaFoldWrapper, self).__init__()
        self.config = config
        self.globals = self.config.globals
        self.model = AlphaFold3(config)

        self.loss = AlphaFold3Loss(config.loss)

        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema_decay
        )
        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()

        # Set matmul precision
        torch.set_float32_matmul_precision(self.globals.matmul_precision)

    def forward(self, batch, training=True):
        return self.model(batch, train=training)

    def _log(self, batch, outputs, loss_breakdown=None, train=True):
        # Loop over loss values and log it
        phase = "train" if train else "val"
        if loss_breakdown is not None:
            for loss_name, indiv_loss in loss_breakdown.items():
                self.log(
                    f"{phase}/{loss_name}",
                    indiv_loss,
                    prog_bar=(loss_name == 'loss'),
                    on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
                )

        # Compute validation metrics
        other_metrics = self._compute_validation_metrics(
            batch,
            outputs,
            superimposition_metrics=True  # (not train)
        )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=train, on_epoch=True, logger=True, sync_dist=True,
            )

    def training_step(self, batch, batch_idx):
        batch = reshape_features(batch)  # temporary

        # Run the model
        outputs = self.forward(batch, training=True)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # For multimer, multichain permutation align the batch

        # Flatten the S to be incorporated into the batch dimension
        # TODO: this is temporary, will be removed once the data pipeline is better written
        outputs["augmented_gt_atoms"] = rearrange(
            outputs["augmented_gt_atoms"], 'b s n c -> (b s) n c'
        )
        outputs["denoised_atoms"] = rearrange(
            outputs["denoised_atoms"], 'b s n c -> (b s) n c'
        )
        outputs["timesteps"] = rearrange(
            outputs["timesteps"], 'b s o -> (b s) o'
        )
        # Expand atom_exists to be of shape (bs * samples_per_trunk, n_atoms)
        samples_per_trunk = outputs["timesteps"].shape[0] // batch["atom_exists"].shape[0]
        expand_batch = lambda tensor: tensor.repeat_interleave(samples_per_trunk, dim=0)
        batch["atom_exists"] = expand_batch(batch["atom_exists"])

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log loss and validation metrics
        self._log(
            loss_breakdown=loss_breakdown,
            batch=batch,
            outputs=outputs,
            train=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        batch = reshape_features(batch)  # temporary

        # Run the model
        outputs = self.forward(batch, training=False)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)  # Remove recycling dimension

        # For multimer, multichain permutation align the batch

        # Compute and log validation metrics
        self._log(loss_breakdown=None, batch=batch, outputs=outputs, train=False)

    def _compute_validation_metrics(
            self,
            batch,
            outputs,
            superimposition_metrics=False
    ):
        """Compute validation metrics for the model."""
        with torch.no_grad():
            batch_size, n_tokens = batch["token_index"].shape
            metrics = {}

            gt_coords = batch["all_atom_positions"]  # (bs, n_atoms, 3) gt_atoms after augmentation
            pred_coords = outputs["sampled_positions"].squeeze(-3)  # remove S dimension (bs, 1, n_atoms, 3)
            all_atom_mask = batch["atom_mask"]  # (bs, n_atoms)

            # Center the gt_coords
            gt_coords = gt_coords - torch.mean(gt_coords, dim=-2, keepdim=True)

            gt_coords_masked = gt_coords * all_atom_mask[..., None]
            pred_coords_masked = pred_coords * all_atom_mask[..., None]

            # Reshape to backbone atom format (temporary, will switch to more general representation)
            gt_coords_masked = gt_coords_masked.reshape(batch_size, n_tokens, 4, 3)
            pred_coords_masked = pred_coords_masked.reshape(batch_size, n_tokens, 4, 3)
            all_atom_mask = all_atom_mask.reshape(batch_size, n_tokens, 4)

            ca_pos = residue_constants.atom_order["CA"]
            gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
            pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
            all_atom_mask_ca = all_atom_mask[..., ca_pos]

            # lddt_ca_score = lddt(
            #    all_atom_pred_pos=pred_coords_masked_ca,
            #    all_atom_positions=gt_coords_masked_ca,
            #    all_atom_mask=all_atom_mask_ca,
            #    eps=self.config.globals.eps,
            #    per_residue=False
            # )
            # metrics["lddt_ca"] = lddt_ca_score

            # drmsd
            drmsd_ca_score = drmsd(
                pred_coords_masked_ca,
                gt_coords_masked_ca,
                mask=all_atom_mask_ca,  # still required here to compute n
            )
            metrics["drmsd_ca"] = drmsd_ca_score

            if superimposition_metrics:
                superimposed_pred, alignment_rmsd = superimpose(
                    gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
                )
                gdt_ts_score = gdt_ts(
                    superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
                )
                gdt_ha_score = gdt_ha(
                    superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
                )

                metrics["alignment_rmsd"] = alignment_rmsd
                metrics["gdt_ts"] = gdt_ts_score
                metrics["gdt_ha"] = gdt_ha_score

            return metrics

    def configure_optimizers(self):
        partial_optimizer = hydra.utils.instantiate(self.config.optimizer)
        partial_scheduler = hydra.utils.instantiate(self.config.scheduler)
        optimizer = partial_optimizer(self.trainer.model.parameters())
        scheduler = partial_scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "name": "AlphaFold3LRScheduler"
                # "frequency": 1,
            },
        }

    # def on_before_optimizer_step(self, optimizer):
    #    """Keeps an eye on gradient norms during training."""
    #    norms = grad_norm(self.model, norm_type=2)
    #    self.log_dict(norms)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """
        Keeps an eye on weight norms during training.
        """
        # Log weight norms
        weight_norms = {}
        for name, param in self.named_parameters():
            weight_norms[f"{name}_abs_mean"] = param.abs().mean().item()
        self.log_dict(weight_norms)

    def on_before_optimizer_step(self, optimizer):
        """Keeps an eye on gradient norms during training."""
        norms = grad_norm(self.model, norm_type=2)
        self.log_dict(norms)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Update EMA after each training batch
        self.ema.update(self.model)
    
    def on_train_batch_start(self, batch: Any, batch_idx: int):
        # Fetch the EMA weights to the device
        if self.ema.device != batch["residue_index"].device:
            self.ema.to(batch["residue_index"].device)

    def on_validation_epoch_start(self):
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling 
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.params)

    def on_validation_epoch_end(self):
        # Restore original model weights
        if self.cached_weights is not None:
            self.model.load_state_dict(self.cached_weights)
            self.cached_weights = None

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        """Lightning hook that is called when loading a checkpoint."""
        ema = checkpoint["ema"]
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        """Lightning hook that is called when saving a checkpoint."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

Code walkthrough

1. __init__: initialization
def __init__(self, config):
    super(AlphaFoldWrapper, self).__init__()
    self.config = config
    self.globals = self.config.globals
    self.model = AlphaFold3(config)

    self.loss = AlphaFold3Loss(config.loss)

    self.ema = ExponentialMovingAverage(
        model=self.model, decay=config.ema_decay
    )
    self.cached_weights = None

    self.last_lr_step = -1
    self.save_hyperparameters()

    # Set matmul precision
    torch.set_float32_matmul_precision(self.globals.matmul_precision)
  1. Initialize the AlphaFold3 model

    • self.model = AlphaFold3(config)
  2. Load the loss module

    • self.loss = AlphaFold3Loss(config.loss)
  3. Enable an exponential moving average (EMA) of the weights

    • self.ema = ExponentialMovingAverage(model=self.model, decay=config.ema_decay)
    • The EMA smooths parameter updates and improves generalization (see the sketch after this list)
  4. Bookkeeping attributes

    • self.cached_weights = None
    • self.last_lr_step = -1
    • self.save_hyperparameters() stores every argument passed to __init__ in self.hparams
  5. Set the matmul precision

    • torch.set_float32_matmul_precision(self.globals.matmul_precision)
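
ExponentialMovingAverage keeps a shadow copy of the weights and blends the live weights into it after every optimizer step (on_train_batch_end calls self.ema.update(self.model)). Below is a minimal sketch of the idea, assuming only the params/update/to interface used by the wrapper; the repo's actual implementation may differ in detail.

import torch
from torch import nn

class SimpleEMA:
    """Sketch of an EMA: shadow = decay * shadow + (1 - decay) * weight."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        # Shadow copies of the weights (clones, not references into the model).
        self.params = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.device = next(iter(self.params.values())).device

    def to(self, device):
        self.params = {k: v.to(device) for k, v in self.params.items()}
        self.device = device

    @torch.no_grad()
    def update(self, model: nn.Module):
        for name, tensor in model.state_dict().items():
            shadow = self.params[name]
            if tensor.dtype.is_floating_point:
                shadow.mul_(self.decay).add_(tensor, alpha=1.0 - self.decay)
            else:
                shadow.copy_(tensor)  # integer buffers are simply copied

Because decay is close to 1, the shadow weights move slowly and average out the noise of individual updates, which is why they are the weights swapped in for validation (on_validation_epoch_start).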
2. forward: forward pass
def forward(self, batch, training=True):
    return self.model(batch, train=training)
  • Wraps the forward method of AlphaFold3
  • The training argument controls whether the model runs in training mode
3. _log: logging training/validation information
def _log(self, batch, outputs, loss_breakdown=None, train=True):
    phase = "train" if train else "val"
    if loss_breakdown is not None:
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )
    
    other_metrics = self._compute_validation_metrics(
        batch,
        outputs,
        superimposition_metrics=True  # (not train)
    )

    for k, v in other_metrics.items():
        self.log(
            f"{phase}/{k}",
            torch.mean(v),
            prog_bar=(k == 'loss'),
            on_step=train, on_epoch=True, logger=True, sync_dist=True,
        )
  • Logs each loss term in loss_breakdown
  • Computes and logs validation metrics
    • _compute_validation_metrics() computes the GDT_TS, GDT_HA, and dRMSD scores
    • self.log() records each metric
4. training_step: a single training step
def training_step(self, batch, batch_idx):
    batch = reshape_features(batch)  # preprocess the batch features (temporary)

    outputs = self.forward(batch, training=True)

    # Remove the recycling dimension (AlphaFold refines predictions through recycling)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)

    # Reshape tensors: fold the sample dimension S into the batch dimension
    outputs["augmented_gt_atoms"] = rearrange(outputs["augmented_gt_atoms"], 'b s n c -> (b s) n c')
    outputs["denoised_atoms"] = rearrange(outputs["denoised_atoms"], 'b s n c -> (b s) n c')
    outputs["timesteps"] = rearrange(outputs["timesteps"], 'b s o -> (b s) o')

    # Compute the loss
    loss, loss_breakdown = self.loss(outputs, batch, _return_breakdown=True)

    # Log the loss and metrics
    self._log(loss_breakdown=loss_breakdown, batch=batch, outputs=outputs, train=True)

    return loss
  • Runs the forward pass
  • Reshapes the data, folding the sample dimension into the batch dimension (see the sketch below)
  • Computes the loss
  • Logs the loss
  • Returns loss
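
The rearrange / repeat_interleave bookkeeping in training_step is easier to see with dummy tensors. A standalone sketch with made-up sizes:

import torch
from einops import rearrange

bs, samples_per_trunk, n_atoms = 2, 3, 8  # illustrative sizes

# The diffusion module produces several denoised samples per trunk pass.
denoised_atoms = torch.randn(bs, samples_per_trunk, n_atoms, 3)  # (b, s, n, 3)
atom_exists = torch.ones(bs, n_atoms)                            # (b, n)

# Fold the sample dimension S into the batch dimension.
denoised_atoms = rearrange(denoised_atoms, 'b s n c -> (b s) n c')  # (b*s, n, 3)

# Expand the per-example mask so every sample reuses its example's mask.
s = denoised_atoms.shape[0] // atom_exists.shape[0]    # recovers samples_per_trunk
atom_exists = atom_exists.repeat_interleave(s, dim=0)  # (b*s, n)

print(denoised_atoms.shape, atom_exists.shape)
# torch.Size([6, 8, 3]) torch.Size([6, 8])

After this flattening, the loss treats every diffusion sample as an independent element of the batch.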
5. validation_step: validation
def validation_step(self, batch, batch_idx):
    batch = reshape_features(batch)

    outputs = self.forward(batch, training=False)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)

    self._log(loss_breakdown=None, batch=batch, outputs=outputs, train=False)
  • Runs forward to produce predictions and strips the recycling dimension (see the sketch below)
  • Logs evaluation metrics only (loss_breakdown is None here, and no gradients are computed)
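
Both training_step and validation_step strip the trailing recycling dimension with tensor_tree_map(lambda t: t[..., -1], batch), keeping only the features fed to the last recycling iteration. A toy illustration of the effect on a flat dict (the repo's tensor_tree_map also handles nested dicts); the feature names and sizes here are made up:

import torch

bs, n_tokens, n_recycles = 2, 5, 4  # illustrative sizes

# Each feature carries a trailing recycling dimension of size n_recycles.
batch = {
    "residue_index": torch.arange(n_tokens).expand(bs, n_tokens)[..., None].repeat(1, 1, n_recycles),
    "token_mask": torch.ones(bs, n_tokens, n_recycles),
}

# Equivalent to tensor_tree_map(lambda t: t[..., -1], batch) for a flat dict:
batch = {k: v[..., -1] for k, v in batch.items()}

print({k: tuple(v.shape) for k, v in batch.items()})
# {'residue_index': (2, 5), 'token_mask': (2, 5)}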
6. _compute_validation_metrics: computing evaluation metrics
def _compute_validation_metrics(self, batch, outputs, superimposition_metrics=False):
    with torch.no_grad():
        batch_size, n_tokens = batch["token_index"].shape
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["sampled_positions"].squeeze(-3)
        all_atom_mask = batch["atom_mask"]

        gt_coords = gt_coords - torch.mean(gt_coords, dim=-2, keepdim=True)

        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]

        gt_coords_masked = gt_coords_masked.reshape(batch_size, n_tokens, 4, 3)
        pred_coords_masked = pred_coords_masked.reshape(batch_size, n_tokens, 4, 3)
        all_atom_mask = all_atom_mask.reshape(batch_size, n_tokens, 4)

        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        # Compute dRMSD
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca
        )
        metrics["drmsd_ca"] = drmsd_ca_score

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ts_score = gdt_ts(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)
            gdt_ha_score = gdt_ha(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics
  • Computes GDT scores after superimposition
  • Computes dRMSD (see the worked sketch below)
  • Computes the alignment RMSD
  • Suited to protein structure evaluation in AlphaFold3
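
dRMSD compares the intra-structure pairwise distance matrices of the prediction and the ground truth, so it needs no superimposition and is invariant to global rotation and translation. A worked sketch of the standard formula; the repo's drmsd() may differ in its masking and eps details:

import torch

def drmsd_sketch(pred, gt, mask, eps=1e-8):
    """dRMSD over CA coordinates: RMS difference of pairwise distance matrices.

    pred, gt: (bs, n, 3) coordinates; mask: (bs, n) with 1 for valid residues.
    """
    def pairwise_dist(x):
        # (bs, n, n) Euclidean distances between all residue pairs
        return torch.sqrt(eps + ((x[..., None, :] - x[..., None, :, :]) ** 2).sum(-1))

    pair_mask = mask[..., None] * mask[..., None, :]           # (bs, n, n)
    sq_diff = (pairwise_dist(pred) - pairwise_dist(gt)) ** 2   # (bs, n, n)
    n_pairs = pair_mask.sum((-1, -2)).clamp(min=1)             # valid pairs per example
    return torch.sqrt((sq_diff * pair_mask).sum((-1, -2)) / n_pairs)

# Example: identical structures give a dRMSD of 0.
coords = torch.randn(1, 10, 3)
print(drmsd_sketch(coords, coords.clone(), torch.ones(1, 10)))  # tensor([0.])

Because only pairwise distances are compared, identical structures related by any rigid-body transform score 0.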
7. configure_optimizers: optimizer setup
def configure_optimizers(self):
    partial_optimizer = hydra.utils.instantiate(self.config.optimizer)
    partial_scheduler = hydra.utils.instantiate(self.config.scheduler)
    optimizer = partial_optimizer(self.trainer.model.parameters())
    scheduler = partial_scheduler(optimizer=optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",
            "name": "AlphaFold3LRScheduler"
        },
    }
  • Instantiates the optimizer from the Hydra config (see the sketch below)
  • Supports a custom learning-rate scheduler
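
configure_optimizers relies on Hydra's partial instantiation: the config supplies _target_ and _partial_: true, so hydra.utils.instantiate returns a functools.partial that is later completed with the model parameters (and, for the scheduler, the optimizer). A minimal sketch of the pattern; the targets and hyperparameters below are illustrative assumptions, not the repo's actual config, and _partial_ requires a recent Hydra version:

import hydra.utils
import torch
from omegaconf import OmegaConf

# Illustrative configs; the real values live in the repo's Hydra YAML files.
optimizer_cfg = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 1.8e-3}
)
scheduler_cfg = OmegaConf.create(
    {"_target_": "torch.optim.lr_scheduler.StepLR", "_partial_": True,
     "step_size": 50_000, "gamma": 0.95}
)

model = torch.nn.Linear(8, 8)

# `_partial_: true` makes instantiate() return a functools.partial,
# which is completed with arguments only known at runtime.
partial_optimizer = hydra.utils.instantiate(optimizer_cfg)
optimizer = partial_optimizer(model.parameters())

partial_scheduler = hydra.utils.instantiate(scheduler_cfg)
scheduler = partial_scheduler(optimizer=optimizer)

This keeps configure_optimizers optimizer-agnostic: swapping the optimizer or scheduler is purely a config change.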

Summary

  • Wraps the AlphaFold3 training/validation logic
  • Integrates loss computation, GDT scoring, and dRMSD
  • Uses EMA to smooth the parameters
  • Relies on PyTorch Lightning to manage training

In short, the AlphaFoldWrapper class ties together the AlphaFold3 training, validation, and optimization workflow, making model training more stable and efficient.
