AF3 MSAModule Class Source Code Walkthrough

The MSAModule class in AlphaFold3 processes the multiple sequence alignment (MSA). Its core job is to stack MSAModuleBlocks, with optional gradient checkpointing, so that the MSA representation and the pair representation can be updated efficiently. Calling the module returns an updated pair representation z that has absorbed information from both the MSA features and the target protein sequence.
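
To make the constructor options concrete, here is a minimal instantiation sketch. The import path is a placeholder (adjust it to wherever MSAModule lives in your copy of the code base), and it assumes MSAModuleBlock and its other dependencies are importable; the argument values are simply the defaults documented in the source listing below.

from src.models.msa_module import MSAModule  # hypothetical import path

msa_module = MSAModule(
    no_blocks=4,                       # four stacked MSAModuleBlocks
    c_msa=64,                          # MSA representation channels
    c_z=128,                           # pair representation channels
    blocks_per_ckpt=1,                 # checkpoint one block at a time
    clear_cache_between_blocks=False,  # optionally empty the CUDA cache between blocks
)
print(len(msa_module.blocks))          # -> 4, one MSAModuleBlock per no_blocks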

Source code:

import torch.nn as nn

# MSAModuleBlock is defined elsewhere in the same code base and is assumed to
# be importable alongside this class.


class MSAModule(nn.Module):
    def __init__(
            self,
            no_blocks: int = 4,
            c_msa: int = 64,
            c_token: int = 384,
            c_z: int = 128,
            c_hidden: int = 32,
            no_heads: int = 8,
            c_hidden_tri_mul: int = 128,
            c_hidden_pair_attn: int = 32,
            no_heads_tri_attn: int = 4,
            transition_n: int = 4,
            pair_dropout: float = 0.25,
            fuse_projection_weights: bool = False,
            clear_cache_between_blocks: bool = False,
            blocks_per_ckpt: int = 1,
            inf: float = 1e8
    ):
        """
        Initialize the MSA module.
        Args:
            no_blocks:
                number of MSAModuleBlocks
            c_msa:
                MSA representation dim
            c_token:
                Single representation dim
            c_z:
                pair representation dim
            c_hidden:
                hidden representation dim
            no_heads:
                number of heads in the pair averaging
            c_hidden_tri_mul:
                hidden dimensionality of triangular multiplicative updates
            c_hidden_pair_attn:
                hidden dimensionality of triangular attention
            no_heads_tri_attn:
                number of heads in triangular attention
            transition_n:
                multiplication factor for the hidden dim during the transition
            pair_dropout:
                dropout rate within the pair stack
            fuse_projection_weights:
                whether to use FusedTriangleMultiplicativeUpdate or not
            blocks_per_ckpt:
                Number of blocks per checkpoint. If None, no checkpointing is used.
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(MSAModule, self).__init__()
        self.blocks = nn.ModuleList([
            MSAModuleBlock(
                c_msa=c_msa,
                c_z=c_z,
                c_hidden=c_hidden,
                no_heads=no_heads,
                c_hidden_tri_mul=c_hidden_tri_mul,
                c_hidden_pair_attn=c_hidden_pair_attn,
                no_heads_tri_attn=no_heads_tri_attn,
                transition_n=transition_n,
                pair_dropout=pair_dropout,
                fuse_projection_weights=fuse_projection_weights,
                inf=inf)
            for _ in range(no_blocks)
        ])
        # (assumed continuation of the truncated listing) store the documented
        # checkpointing / cache options for use by the rest of the module
        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks
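
The blocks_per_ckpt option enables activation (gradient) checkpointing across the block stack: instead of storing every intermediate activation for the backward pass, groups of blocks are re-executed during backpropagation, trading compute for memory. The listing above stops before the forward pass, so the snippet below is only a self-contained sketch of that general pattern using torch.utils.checkpoint, not the module's actual forward code.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-in blocks; in MSAModule these would be MSAModuleBlocks updating (m, z).
blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
blocks_per_ckpt = 2  # how many blocks to fold into one checkpointed segment

def run_segment(start, end, x):
    # Re-executed during the backward pass instead of keeping its activations.
    for blk in blocks[start:end]:
        x = torch.relu(blk(x))
    return x

x = torch.randn(8, 16, requires_grad=True)
for start in range(0, len(blocks), blocks_per_ckpt):
    end = min(start + blocks_per_ckpt, len(blocks))
    # clear_cache_between_blocks would correspond to calling
    # torch.cuda.empty_cache() at this point in the loop.
    x = checkpoint(run_segment, start, end, x, use_reentrant=False)
x.sum().backward()  # segment activations are recomputed here

With blocks_per_ckpt=1 every block is its own segment, which is the most memory-frugal setting; grouping more blocks into a segment raises peak memory during that segment's backward pass but reduces checkpointing overhead.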