The MSAModule in AlphaFold3

The MSAModule class is a module for processing multiple sequence alignments (MSAs). Its core function is to stack MSAModuleBlock layers, using gradient checkpointing to reduce memory use, and thereby compute updates to the MSA representation and the pair representation efficiently. Calling the module ultimately returns the updated pair representation z, which now incorporates both MSA features and information about the target protein sequence.
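The block stacking with gradient checkpointing described above can be sketched as follows. This is a simplified illustration using torch.utils.checkpoint: DummyBlock and BlockStack are hypothetical stand-ins for MSAModuleBlock and the real module, and checkpointing every block individually is an assumption (the actual code groups blocks via blocks_per_ckpt), not the exact AlphaFold3 implementation.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DummyBlock(nn.Module):
    """Stand-in for MSAModuleBlock: one residual linear update of z."""
    def __init__(self, c_z: int):
        super().__init__()
        self.linear = nn.Linear(c_z, c_z)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return z + self.linear(z)


class BlockStack(nn.Module):
    """Stacks no_blocks blocks; checkpoints each block during training
    so activations are recomputed in the backward pass instead of stored."""
    def __init__(self, no_blocks: int = 4, c_z: int = 128):
        super().__init__()
        self.blocks = nn.ModuleList(DummyBlock(c_z) for _ in range(no_blocks))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training and z.requires_grad:
                # Recompute this block's activations during backward
                z = checkpoint(block, z, use_reentrant=False)
            else:
                z = block(z)
        return z


stack = BlockStack(no_blocks=4, c_z=128)
# z shaped [batch, N_token, N_token, c_z], like a pair representation
z = torch.randn(2, 10, 10, 128, requires_grad=True)
out = stack(z)
```

The memory saving comes from discarding each block's intermediate activations after the forward pass; the trade-off is one extra forward computation per block during backward.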
Source code:
class MSAModule(nn.Module):
    def __init__(
        self,
        no_blocks: int = 4,
        c_msa: int = 64,
        c_token: int = 384,
        c_z: int = 128,
        c_hidden: int = 32,
        no_heads: int = 8,
        c_hidden_tri_mul: int = 128,
        c_hidden_pair_attn: int = 32,
        no_heads_tri_attn: int = 4,
        transition_n: int = 4,
        pair_dropout: float = 0.25,
        fuse_projection_weights: bool = False,
        clear_cache_between_blocks: bool = False,
        blocks_per_ckpt: int = 1,
        inf: float = 1e8
    ):
        """
        Initialize the MSA module.
        Args:
            no_blocks:
                number of MSAModuleBlocks
            c_msa:
                MSA representation dim
            c_token:
                single representation dim
            c_z:
                pair representation dim
            c_hidden:
                hidden representation dim
            no_heads:
                number of heads in the pair-weighted averaging
            c_hidden_tri_mul:
                hidden dimensionality of triangular multiplicative updates
            c_hidden_pair_attn:
                hidden dimensionality of triangular attention
            no_heads_tri_attn:
                number of heads in triangular attention
            transition_n:
                multiplication factor for the hidden dim during the transition
            pair_dropout:
                dropout rate within the pair stack
            fuse_projection_weights:
                whether to use FusedTriangleMultiplicativeUpdate or not
            clear_cache_between_blocks:
                whether to clear CUDA's GPU memory cache between blocks of the
                stack; slows down each block but can reduce fragmentation
            blocks_per_ckpt:
                number of blocks per checkpoint; if None, no checkpointing is used
            inf:
                large value used for masking in attention
        """
        super(MSAModule, self).__init__()
        self.blocks = nn.ModuleList([
            MSAModuleBlock(
                c_msa=c_msa,
                c_z=c_z,
                c_hidden=c_hidden,
                no_heads=no_heads,
                c_hidden_tri_mul=c_hidden_tri_mul,
                c_hidden_pair_attn=c_hidden_pair_attn,
                no_heads_tri_attn=no_heads_tri_attn,
                transition_n=transition_n,
                pair_dropout=pair_dropout,
                fuse_projection_weights=fuse_projection_weights,
                inf=inf)
            for _ in range(no_blocks)