The MSAColumnAttention and MSAColumnGlobalAttention classes call the Attention and GlobalAttention classes, respectively. Before attention is computed, the input MSA representation is transposed (its sequence and residue axes are swapped) so that attention is applied over the columns of the MSA; afterwards the axes are swapped back, giving the updated MSA representation.
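As a quick shape check of that axis swap (a toy illustration with arbitrary sizes, not AlphaFold code):
import jax.numpy as jnp

msa_act = jnp.zeros((8, 16, 32))         # [N_seq=8, N_res=16, c_m=32]
col_act = jnp.swapaxes(msa_act, -2, -3)  # [N_res=16, N_seq=8, c_m=32]
# Row-wise attention applied to col_act now mixes information across the
# N_seq axis, i.e. across sequences within each residue column.
print(col_act.shape)  # (16, 8, 32)
The two classes below implement exactly this pattern.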
import haiku as hk
import jax
import jax.numpy as jnp
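# Note: LayerNorm, Attention, GlobalAttention and inference_subbatch are not
# defined in this excerpt; they come from elsewhere in the AlphaFold model code
# (common_modules.py, modules.py and mapping.py).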
class MSAColumnAttention(hk.Module):
"""MSA per-column attention.
Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
"""
def __init__(self, config, global_config, name='msa_column_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m]
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
# jnp.swapaxes exchanges two axes of an array (tensor).
# After the swap, running attention row-wise over msa_act is equivalent to
# attending over the columns of the original MSA.
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
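# mask: [N_res, 1, 1, N_seq]; broadcasts over attention heads and query positions.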
mask = msa_mask[:, None, None, :]
assert len(mask.shape) == 4
msa_act = LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, mask],
nonbatched_args=[],
low_memory=not is_training)
# Swap the axes back so msa_act is again [N_seq, N_res, c_m].
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
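To make the data flow concrete, here is a minimal self-contained sketch of per-column dot-product attention (single head, no learned projections, no subbatching). column_attention_sketch is a hypothetical helper written for illustration; it is not AlphaFold's Attention or inference_subbatch.
import jax
import jax.numpy as jnp

def column_attention_sketch(msa_act, msa_mask):
    # msa_act: [N_seq, N_res, c_m]; msa_mask: [N_seq, N_res]
    act = jnp.swapaxes(msa_act, -2, -3)    # [N_res, N_seq, c_m]
    mask = jnp.swapaxes(msa_mask, -1, -2)  # [N_res, N_seq]
    logits = jnp.einsum('rqc,rkc->rqk', act, act) / jnp.sqrt(act.shape[-1])
    logits = jnp.where(mask[:, None, :] > 0, logits, -1e9)  # mask padded sequences
    weights = jax.nn.softmax(logits, axis=-1)  # attention across sequences, per column
    out = jnp.einsum('rqk,rkc->rqc', weights, act)
    return jnp.swapaxes(out, -2, -3)           # back to [N_seq, N_res, c_m]

msa = jnp.ones((4, 6, 8))   # toy [N_seq, N_res, c_m]
mask = jnp.ones((4, 6))
print(column_attention_sketch(msa, mask).shape)  # (4, 6, 8)
AlphaFold's real Attention additionally uses multiple heads, learned query/key/value projections and gating, and inference_subbatch chunks the leading N_res batch axis into pieces of subbatch_size to reduce peak memory when low_memory is enabled (i.e. when not training).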
class MSAColumnGlobalAttention(hk.Module):
"""MSA per-column global attention.
Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
"""
def __init__(self, config, global_config, name='msa_column_global_attention'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self,
msa_act,
msa_mask,
is_training=False):
"""Builds MSAColumnGlobalAttention module.
Arguments:
msa_act: [N_seq, N_res, c_m] MSA representation.
msa_mask: [N_seq, N_res] mask of non-padded regions.
is_training: Whether the module is in training mode.
Returns:
Update to msa_act, shape [N_seq, N_res, c_m].
"""
c = self.config
assert len(msa_act.shape) == 3
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_column'
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
msa_act = LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
attn_mod = GlobalAttention(
c, self.global_config, msa_act.shape[-1],
name='attention')
# msa_mask: [N_res, N_seq, 1] after expand_dims (the axes were swapped above)
msa_mask = jnp.expand_dims(msa_mask, axis=-1)
msa_act = inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, msa_mask],
nonbatched_args=[],
low_memory=not is_training)
msa_act = jnp.swapaxes(msa_act, -2, -3)
return msa_act
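The structural difference from MSAColumnAttention is the use of GlobalAttention: instead of one query per sequence, it forms a single mask-weighted mean query per column, so the cost of the attention is linear rather than quadratic in N_seq. Below is a minimal sketch of that pooled-query idea (global_column_attention_sketch is hypothetical and omits the multiple heads and per-position gating of AlphaFold's GlobalAttention).
import jax
import jax.numpy as jnp

def global_column_attention_sketch(msa_act, msa_mask, eps=1e-10):
    # msa_act: [N_seq, N_res, c_m]; msa_mask: [N_seq, N_res]
    act = jnp.swapaxes(msa_act, -2, -3)    # [N_res, N_seq, c_m]
    mask = jnp.swapaxes(msa_mask, -1, -2)  # [N_res, N_seq]
    # Mask-weighted mean over sequences gives one query per column.
    q = jnp.sum(act * mask[..., None], axis=-2) / (
        jnp.sum(mask, axis=-1, keepdims=True) + eps)            # [N_res, c_m]
    logits = jnp.einsum('rc,rkc->rk', q, act) / jnp.sqrt(act.shape[-1])
    logits = jnp.where(mask > 0, logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)                   # [N_res, N_seq]
    pooled = jnp.einsum('rk,rkc->rc', weights, act)             # [N_res, c_m]
    # Broadcast the pooled update back to every sequence in the column.
    out = jnp.broadcast_to(pooled[:, None, :], act.shape)       # [N_res, N_seq, c_m]
    return jnp.swapaxes(out, -2, -3)                            # [N_seq, N_res, c_m]

msa = jnp.ones((4, 6, 8))   # toy [N_seq, N_res, c_m]
mask = jnp.ones((4, 6))
print(global_column_attention_sketch(msa, mask).shape)  # (4, 6, 8)
In AlphaFold this global variant is used in the extra-MSA stack, where the number of sequences can be in the thousands, which is why the cheaper pooled query matters there.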