TriangleAttention 模块调用 Attention 模块：它需要先计算 nonbatched_bias，并将其加到注意力分数 logits 中（见 Attention 代码）。TriangleAttentionStartingNode 对 pair_act 按行进行注意力计算；TriangleAttentionEndingNode 对 pair_act 按列进行注意力计算。
import jax
import haiku as hk
import jax.numpy as jnp
class TriangleAttention(hk.Module):
    """Triangle attention over the pair representation.

    Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
    Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"

    ``config.orientation`` picks the variant: with 'per_row' the input is
    used as given; with 'per_column' the two residue axes are swapped before
    attention and swapped back afterwards, so a single code path serves both.
    """

    def __init__(self, config, global_config, name='triangle_attention'):
        """Stores the module-level and global configs.

        Arguments:
          config: Module config; this class reads `orientation` and
            `num_head` from it and forwards it to `Attention`.
          global_config: Global config; this class reads `subbatch_size`
            from it and forwards it to `Attention`.
          name: Haiku module name.
        """
        super().__init__(name=name)
        self.global_config = global_config
        self.config = config

    def __call__(self, pair_act, pair_mask, is_training=False):
        """Builds TriangleAttention module.

        Arguments:
          pair_act: [N_res, N_res, c_z] pair activations tensor
          pair_mask: [N_res, N_res] mask of non-padded regions in the tensor.
          is_training: Whether the module is in training mode.

        Returns:
          Update to pair_act, shape [N_res, N_res, c_z].
        """
        cfg = self.config

        assert len(pair_act.shape) == 3
        assert len(pair_mask.shape) == 2
        assert cfg.orientation in ['per_row', 'per_column']

        # Column-wise attention is row-wise attention on the transposed pair
        # tensor; the swap is undone on the way out.
        column_wise = cfg.orientation == 'per_column'
        if column_wise:
            pair_mask = jnp.swapaxes(pair_mask, -1, -2)
            pair_act = jnp.swapaxes(pair_act, -2, -3)

        # Insert two singleton axes between the mask's own two axes so it is
        # rank 4 when handed to the attention module.
        mask = pair_mask[:, None, None, :]
        assert len(mask.shape) == 4

        query_norm = common_modules.LayerNorm(
            axis=[-1], create_scale=True, create_offset=True,
            name='query_norm')
        pair_act = query_norm(pair_act)

        # Linear projection of the normalised pair activations to one bias
        # value per head: [q, k, c] x [c, h] -> [h, q, k].
        num_channels = pair_act.shape[-1]
        init_factor = 1. / jnp.sqrt(int(num_channels))
        bias_projection = hk.get_parameter(
            'feat_2d_weights',
            shape=(num_channels, cfg.num_head),
            dtype=pair_act.dtype,
            init=hk.initializers.RandomNormal(stddev=init_factor))
        nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, bias_projection)

        attention = Attention(cfg, self.global_config, num_channels)

        # The activations are passed twice (as the first two batched args);
        # the bias is passed as a non-batched arg. Low-memory sub-batching is
        # requested only when not training.
        attended = inference_subbatch(
            attention,
            self.global_config.subbatch_size,
            batched_args=[pair_act, pair_act, mask],
            nonbatched_args=[nonbatched_bias],
            low_memory=not is_training)

        return jnp.swapaxes(attended, -2, -3) if column_wise else attended