AlphaFold3: the AlphaFoldWrapper class
AlphaFoldWrapper inherits from LightningModule and serves as the training, validation, and optimization management module for AlphaFold3. Its main responsibilities are (a minimal usage sketch follows this list):
- Managing the training/validation loops
- Computing losses and logging them
- Integrating an exponential moving average (EMA) of the model weights
- Configuring the optimizer and learning-rate schedule
- Managing model parameters (cached vs. EMA weights) during training and validation
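Before diving into the code, here is a minimal sketch of how such a wrapper is typically driven by a Lightning Trainer. The config object cfg and the AlphaFold3DataModule name are illustrative assumptions, not part of the original code:

import pytorch_lightning as pl

def train(cfg):
    # Build the LightningModule described in this article
    model = AlphaFoldWrapper(cfg)
    # Hypothetical data module that yields the feature batches the wrapper expects
    datamodule = AlphaFold3DataModule(cfg)
    trainer = pl.Trainer(
        max_steps=cfg.trainer.max_steps,
        accelerator="gpu",
        devices=cfg.trainer.devices,
        precision=cfg.trainer.precision,
    )
    # Lightning then calls training_step, validation_step and the hooks shown below
    trainer.fit(model, datamodule=datamodule)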
Source code:
class AlphaFoldWrapper(LightningModule):
    def __init__(self, config):
        super(AlphaFoldWrapper, self).__init__()
        self.config = config
        self.globals = self.config.globals
        self.model = AlphaFold3(config)
        self.loss = AlphaFold3Loss(config.loss)

        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema_decay
        )
        self.cached_weights = None
        self.last_lr_step = -1

        self.save_hyperparameters()

        # Set matmul precision
        torch.set_float32_matmul_precision(self.globals.matmul_precision)

    def forward(self, batch, training=True):
        return self.model(batch, train=training)

    def _log(self, batch, outputs, loss_breakdown=None, train=True):
        # Loop over loss values and log it
        phase = "train" if train else "val"
        if loss_breakdown is not None:
            for loss_name, indiv_loss in loss_breakdown.items():
                self.log(
                    f"{phase}/{loss_name}",
                    indiv_loss,
                    prog_bar=(loss_name == 'loss'),
                    on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
                )

        # Compute validation metrics
        other_metrics = self._compute_validation_metrics(
            batch,
            outputs,
            superimposition_metrics=True  # (not train)
        )
        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=train, on_epoch=True, logger=True, sync_dist=True,
            )

    def training_step(self, batch, batch_idx):
        batch = reshape_features(batch)  # temporary

        # Run the model
        outputs = self.forward(batch, training=True)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # For multimer, multichain permutation align the batch

        # Flatten the S to be incorporated into the batch dimension
        # TODO: this is temporary, will be removed once the data pipeline is better written
        outputs["augmented_gt_atoms"] = rearrange(
            outputs["augmented_gt_atoms"], 'b s n c -> (b s) n c'
        )
        outputs["denoised_atoms"] = rearrange(
            outputs["denoised_atoms"], 'b s n c -> (b s) n c'
        )
        outputs["timesteps"] = rearrange(
            outputs["timesteps"], 'b s o -> (b s) o'
        )

        # Expand atom_exists to be of shape (bs * samples_per_trunk, n_atoms)
        samples_per_trunk = outputs["timesteps"].shape[0] // batch["atom_exists"].shape[0]
        expand_batch = lambda tensor: tensor.repeat_interleave(samples_per_trunk, dim=0)
        batch["atom_exists"] = expand_batch(batch["atom_exists"])

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log loss and validation metrics
        self._log(
            loss_breakdown=loss_breakdown,
            batch=batch,
            outputs=outputs,
            train=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        batch = reshape_features(batch)  # temporary

        # Run the model
        outputs = self.forward(batch, training=False)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)  # Remove recycling dimension

        # For multimer, multichain permutation align the batch

        # Compute and log validation metrics
        self._log(loss_breakdown=None, batch=batch, outputs=outputs, train=False)

    def _compute_validation_metrics(
            self,
            batch,
            outputs,
            superimposition_metrics=False
    ):
        """Compute validation metrics for the model."""
        with torch.no_grad():
            batch_size, n_tokens = batch["token_index"].shape
            metrics = {}

            gt_coords = batch["all_atom_positions"]  # (bs, n_atoms, 3) gt_atoms after augmentation
            pred_coords = outputs["sampled_positions"].squeeze(-3)  # remove S dimension (bs, 1, n_atoms, 3)
            all_atom_mask = batch["atom_mask"]  # (bs, n_atoms)

            # Center the gt_coords
            gt_coords = gt_coords - torch.mean(gt_coords, dim=-2, keepdim=True)

            gt_coords_masked = gt_coords * all_atom_mask[..., None]
            pred_coords_masked = pred_coords * all_atom_mask[..., None]

            # Reshape to backbone atom format (temporary, will switch to more general representation)
            gt_coords_masked = gt_coords_masked.reshape(batch_size, n_tokens, 4, 3)
            pred_coords_masked = pred_coords_masked.reshape(batch_size, n_tokens, 4, 3)
            all_atom_mask = all_atom_mask.reshape(batch_size, n_tokens, 4)

            ca_pos = residue_constants.atom_order["CA"]
            gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
            pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
            all_atom_mask_ca = all_atom_mask[..., ca_pos]

            # lddt_ca_score = lddt(
            #     all_atom_pred_pos=pred_coords_masked_ca,
            #     all_atom_positions=gt_coords_masked_ca,
            #     all_atom_mask=all_atom_mask_ca,
            #     eps=self.config.globals.eps,
            #     per_residue=False
            # )
            # metrics["lddt_ca"] = lddt_ca_score

            # drmsd
            drmsd_ca_score = drmsd(
                pred_coords_masked_ca,
                gt_coords_masked_ca,
                mask=all_atom_mask_ca,  # still required here to compute n
            )
            metrics["drmsd_ca"] = drmsd_ca_score

            if superimposition_metrics:
                superimposed_pred, alignment_rmsd = superimpose(
                    gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
                )
                gdt_ts_score = gdt_ts(
                    superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
                )
                gdt_ha_score = gdt_ha(
                    superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
                )
                metrics["alignment_rmsd"] = alignment_rmsd
                metrics["gdt_ts"] = gdt_ts_score
                metrics["gdt_ha"] = gdt_ha_score
            return metrics

    def configure_optimizers(self):
        partial_optimizer = hydra.utils.instantiate(self.config.optimizer)
        partial_scheduler = hydra.utils.instantiate(self.config.scheduler)
        optimizer = partial_optimizer(self.trainer.model.parameters())
        scheduler = partial_scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "name": "AlphaFold3LRScheduler"
                # "frequency": 1,
            },
        }

    # def on_before_optimizer_step(self, optimizer):
    #     """Keeps an eye on gradient norms during training."""
    #     norms = grad_norm(self.model, norm_type=2)
    #     self.log_dict(norms)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Keeps an eye on weight norms during training."""
        # Log weight norms
        weight_norms = {}
        for name, param in self.named_parameters():
            weight_norms[f"{name}_abs_mean"] = param.abs().mean().item()
        self.log_dict(weight_norms)

    def on_before_optimizer_step(self, optimizer):
        """Keeps an eye on gradient norms during training."""
        norms = grad_norm(self.model, norm_type=2)
        self.log_dict(norms)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Update EMA after each training batch
        self.ema.update(self.model)

    def on_train_batch_start(self, batch: Any, batch_idx: int):
        # Fetch the EMA weights to the device
        if self.ema.device != batch["residue_index"].device:
            self.ema.to(batch["residue_index"].device)

    def on_validation_epoch_start(self):
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
        self.model.load_state_dict(self.ema.params)

    def on_validation_epoch_end(self):
        # Restore original model weights
        if self.cached_weights is not None:
            self.model.load_state_dict(self.cached_weights)
            self.cached_weights = None

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
        """Lightning hook that is called when loading a checkpoint."""
        ema = checkpoint["ema"]
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        """Lightning hook that is called when saving a checkpoint."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step
Code walkthrough
1. __init__: Initialization
def __init__(self, config):
    super(AlphaFoldWrapper, self).__init__()
    self.config = config
    self.globals = self.config.globals
    self.model = AlphaFold3(config)
    self.loss = AlphaFold3Loss(config.loss)
    self.ema = ExponentialMovingAverage(
        model=self.model, decay=config.ema_decay
    )
    self.cached_weights = None
    self.last_lr_step = -1
    self.save_hyperparameters()
    # Set matmul precision
    torch.set_float32_matmul_precision(self.globals.matmul_precision)
- Instantiate the AlphaFold3 model: self.model = AlphaFold3(config)
- Load the loss module: self.loss = AlphaFold3Loss(config.loss)
- Enable the exponential moving average (EMA): self.ema = ExponentialMovingAverage(model=self.model, decay=config.ema_decay). The EMA smooths parameter updates and tends to improve generalization (see the sketch after this list).
- Manage bookkeeping state: self.cached_weights = None and self.last_lr_step = -1; self.save_hyperparameters() stores every argument passed to __init__ in the self.hparams attribute.
- Set the float32 matmul precision: torch.set_float32_matmul_precision(self.globals.matmul_precision)
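The ExponentialMovingAverage class is specific to this codebase, but the update rule it applies is the standard one. A minimal sketch of the idea, assuming decay is a float close to 1 (SimpleEMA is an illustrative stand-in, not the repo's actual implementation):

import torch

class SimpleEMA:
    # Shadow copy of the weights, updated as shadow = decay * shadow + (1 - decay) * param
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.params = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.params[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                # Non-float buffers (e.g. integer counters) are copied directly
                self.params[k].copy_(v)

Because the shadow weights change much more slowly than the raw weights, evaluating with them (as on_validation_epoch_start does in the source above) usually gives smoother and slightly better validation scores.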
2. forward: Forward pass
def forward(self, batch, training=True):
    return self.model(batch, train=training)
- Wraps the forward method of AlphaFold3
- The training argument controls whether the model runs in training mode
3. _log: Logging training/validation information
def _log(self, batch, outputs, loss_breakdown=None, train=True):
    phase = "train" if train else "val"
    if loss_breakdown is not None:
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )
    other_metrics = self._compute_validation_metrics(
        batch,
        outputs,
        superimposition_metrics=True  # (not train)
    )
    for k, v in other_metrics.items():
        self.log(
            f"{phase}/{k}",
            torch.mean(v),
            prog_bar=(k == 'loss'),
            on_step=train, on_epoch=True, logger=True, sync_dist=True,
        )
- Logs every term in loss_breakdown under a train/ or val/ prefix
- Computes and logs validation metrics: _compute_validation_metrics() returns dRMSD and, optionally, GDT_TS and GDT_HA scores, which are recorded with self.log()
4. training_step: A single training step
def training_step(self, batch, batch_idx):
    batch = reshape_features(batch)  # preprocess the batch features (temporary)
    outputs = self.forward(batch, training=True)
    # Remove the recycling dimension (AlphaFold reuses its own outputs over several recycling iterations)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    # Reshape tensors: fold the sample dimension S into the batch dimension
    outputs["augmented_gt_atoms"] = rearrange(outputs["augmented_gt_atoms"], 'b s n c -> (b s) n c')
    outputs["denoised_atoms"] = rearrange(outputs["denoised_atoms"], 'b s n c -> (b s) n c')
    outputs["timesteps"] = rearrange(outputs["timesteps"], 'b s o -> (b s) o')
    # Compute the loss
    loss, loss_breakdown = self.loss(outputs, batch, _return_breakdown=True)
    # Log the loss breakdown
    self._log(loss_breakdown=loss_breakdown, batch=batch, outputs=outputs, train=True)
    return loss
- Run the forward pass
- Reshape the outputs: fold the sample dimension into the batch dimension (see the sketch after this list)
- Compute the loss
- Log the loss breakdown and metrics
- Return loss
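The reshaping step is worth unpacking: the diffusion module returns several samples per trunk (the S dimension), and both the predictions and the ground-truth features must be flattened so the loss sees a plain batch. A small self-contained illustration of the two operations used; the shapes are made up for the example:

import torch
from einops import rearrange

bs, s, n_atoms = 2, 3, 5           # batch, samples per trunk, atoms (toy sizes)
denoised = torch.randn(bs, s, n_atoms, 3)

# 'b s n c -> (b s) n c' merges the sample dimension into the batch dimension
flat = rearrange(denoised, 'b s n c -> (b s) n c')
print(flat.shape)                  # torch.Size([6, 5, 3])

# Ground-truth features only have a batch dimension, so each entry is repeated
# s times with repeat_interleave to line up with the flattened predictions
atom_exists = torch.ones(bs, n_atoms)
expanded = atom_exists.repeat_interleave(s, dim=0)
print(expanded.shape)              # torch.Size([6, 5])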
5. validation_step: Validation
def validation_step(self, batch, batch_idx):
    batch = reshape_features(batch)
    outputs = self.forward(batch, training=False)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    self._log(loss_breakdown=None, batch=batch, outputs=outputs, train=False)
- Runs forward in inference mode to obtain predictions, then strips the recycling dimension with tensor_tree_map (see the sketch after this list)
- Logs only the evaluation metrics (loss_breakdown is None here), without computing gradients
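The lambda t: t[..., -1] passed to tensor_tree_map simply keeps the last recycling iteration of every feature tensor. tensor_tree_map is a repo utility; the sketch below reimplements the idea for a nested dict of tensors to show what the call does (tree_map_dict is an illustrative name, not the actual helper):

import torch

def tree_map_dict(fn, tree):
    # Minimal stand-in for tensor_tree_map, handling nested dicts of tensors
    if isinstance(tree, dict):
        return {k: tree_map_dict(fn, v) for k, v in tree.items()}
    return fn(tree)

batch = {
    "residue_index": torch.arange(8).reshape(1, 4, 2),  # toy (..., n_recycles) shape
    "token_index": torch.arange(8).reshape(1, 4, 2),
}
# Keep only the features of the final recycling iteration
batch = tree_map_dict(lambda t: t[..., -1], batch)
print(batch["residue_index"].shape)   # torch.Size([1, 4])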
6. _compute_validation_metrics: Computing evaluation metrics
def _compute_validation_metrics(self, batch, outputs, superimposition_metrics=False):
    with torch.no_grad():
        batch_size, n_tokens = batch["token_index"].shape
        metrics = {}
        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["sampled_positions"].squeeze(-3)
        all_atom_mask = batch["atom_mask"]
        gt_coords = gt_coords - torch.mean(gt_coords, dim=-2, keepdim=True)
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        gt_coords_masked = gt_coords_masked.reshape(batch_size, n_tokens, 4, 3)
        pred_coords_masked = pred_coords_masked.reshape(batch_size, n_tokens, 4, 3)
        all_atom_mask = all_atom_mask.reshape(batch_size, n_tokens, 4)
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]
        # Compute dRMSD over the CA atoms
        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca
        )
        metrics["drmsd_ca"] = drmsd_ca_score
        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ts_score = gdt_ts(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)
            gdt_ha_score = gdt_ha(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)
            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score
        return metrics
- Computes GDT scores after superposition
- Computes dRMSD over the CA atoms
- Computes the alignment RMSD (the residual error after superposition)
- These are standard structure-quality metrics for evaluating AlphaFold3 predictions (see the sketch after this list)
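The drmsd, superimpose, gdt_ts and gdt_ha helpers are utilities provided by the codebase. As a rough reference for what the two main scores measure, here is a simplified, unmasked sketch over CA coordinates; the real implementations handle masking, batching and numerical stability, and the function names below are illustrative only:

import torch

def drmsd_simple(pred_ca, gt_ca):
    # Distance-matrix RMSD: compare all pairwise CA-CA distances, no superposition needed
    d_pred = torch.cdist(pred_ca, pred_ca)
    d_gt = torch.cdist(gt_ca, gt_ca)
    n = pred_ca.shape[-2]
    return torch.sqrt(((d_pred - d_gt) ** 2).sum() / (n * (n - 1)))

def gdt_ts_simple(superimposed_ca, gt_ca):
    # GDT_TS: average fraction of CA atoms within 1, 2, 4 and 8 Angstrom after superposition
    dist = torch.linalg.norm(superimposed_ca - gt_ca, dim=-1)
    fractions = [(dist <= cutoff).float().mean() for cutoff in (1.0, 2.0, 4.0, 8.0)]
    return torch.stack(fractions).mean()

GDT_HA ("high accuracy") uses the tighter cutoffs 0.5, 1, 2 and 4 Angstrom, and alignment_rmsd is the ordinary RMSD that remains after the optimal superposition.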
7. configure_optimizers: Configuring the optimizer
def configure_optimizers(self):
    partial_optimizer = hydra.utils.instantiate(self.config.optimizer)
    partial_scheduler = hydra.utils.instantiate(self.config.scheduler)
    optimizer = partial_optimizer(self.trainer.model.parameters())
    scheduler = partial_scheduler(optimizer=optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",
            "name": "AlphaFold3LRScheduler"
        },
    }
- Instantiates the optimizer and scheduler from the Hydra config (see the sketch after this list)
- Supports a custom learning-rate scheduler, stepped once per training step
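hydra.utils.instantiate either builds the object directly or, when the config sets _partial_: true (supported in recent Hydra versions), returns a functools.partial that still expects the remaining arguments, here the model parameters for the optimizer and the optimizer for the scheduler. A minimal illustration with an inline dict config; the actual target classes and values in the repo's config may differ:

import hydra
import torch

optimizer_cfg = {
    "_target_": "torch.optim.AdamW",
    "_partial_": True,          # instantiate returns functools.partial(AdamW, lr=..., ...)
    "lr": 1.8e-3,
    "betas": [0.9, 0.95],
    "eps": 1e-8,
}
partial_optimizer = hydra.utils.instantiate(optimizer_cfg)
model = torch.nn.Linear(4, 4)   # stand-in for self.trainer.model
optimizer = partial_optimizer(model.parameters())

Returning the scheduler with interval set to "step" tells Lightning to call scheduler.step() after every optimizer step rather than once per epoch.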
Summary
- Wraps the AlphaFold3 training/validation logic
- Integrates loss computation, GDT scoring, and dRMSD
- Uses EMA to smooth the model parameters
- Uses PyTorch Lightning for training management
The main role of the AlphaFoldWrapper class is to tie together the AlphaFold3 training, validation, and optimization workflow, making model training more stable and efficient.