MSAStack
类是 AlphaFold3 中多序列比对(MSA)处理的核心模块之一,主要作用是通过一系列操作对 MSA 特征进行加权平均、正则化和更新,从而提取更高质量的序列-序列关系特征。
源代码:
class MSAStack(nn.Module):
"""MSA stack that applies pair weighted averaging, dropout, and transition."""
def __init__(
self,
c_msa: int,
c_z: int,
c_hidden: int = 8,
no_heads: int = 8,
dropout: float = 0.15,
inf: float = 1e8
):
super(MSAStack, self).__init__()
self.msa_pair_avg = MSAPairWeightedAveraging(
c_msa=c_msa,
c_z=c_z,
c_hidden=c_hidden,
no_heads=no_heads,
inf=inf
)
self.dropout_row_layer = DropoutRowwise(dropout)
self.transition = Transition(c_msa)
def forward(
self,
m: Tensor,
z: