import collections.abc
import numbers

import haiku as hk
import jax
import jax.numpy as jnp
def stable_softmax(logits: jax.Array) -> jax.Array:
    """Softmax that stays numerically stable for bfloat16 inputs.

    Arguments:
      logits: Array of dtype float32 or bfloat16.

    Returns:
      `jax.nn.softmax(logits)` (default axis) with the same dtype as the input.

    Raises:
      ValueError: If the input dtype is neither float32 nor bfloat16.
    """
    dtype = logits.dtype

    if dtype == jnp.float32:
        return jax.nn.softmax(logits)

    if dtype == jnp.bfloat16:
        # Run the softmax itself in float32: very large negative logits (as
        # produced when masking positions so that they softmax to zero) cause
        # numerical trouble in bfloat16. Cast back afterwards.
        upcast = logits.astype(jnp.float32)
        return jax.nn.softmax(upcast).astype(jnp.bfloat16)

    raise ValueError(f'Unexpected input dtype {logits.dtype}')
def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
    """Masked mean of `value` weighted by `mask`.

    Arguments:
      mask: Array of weights. After the optional channel drop it must have the
        same rank as `value`; along every reduced axis its size must either
        equal the size of `value` or be 1 (broadcast).
      value: Array to average.
      axis: Axis or axes to reduce over: an integer, an iterable of integers,
        or None to reduce over all axes (the result is then a scalar).
      drop_mask_channel: If True, `mask[..., 0]` is used, i.e. the last
        dimension of the mask is removed by taking its first entry.
      eps: Small constant added to the denominator to avoid division by zero.

    Returns:
      `sum(mask * value) / (sum(mask) * broadcast_factor + eps)` reduced over
      `axis`, where `broadcast_factor` is the product of the sizes of `value`
      along reduced axes on which the mask has size 1.

    Raises:
      AssertionError: If the ranks differ, `axis` is not an integer, iterable
        or None, or a reduced mask dimension is neither 1 nor equal to the
        corresponding dimension of `value`.
    """
    if drop_mask_channel:
        # Take the first entry of the last dimension; the rank drops by one.
        mask = mask[..., 0]

    mask_shape = mask.shape
    value_shape = value.shape
    assert len(mask_shape) == len(value_shape)

    # `axis` selects the dimensions over which the masked mean is taken.
    if isinstance(axis, numbers.Integral):
        axis = [axis]
    elif axis is None:
        # Reduce over every dimension.
        axis = list(range(len(mask_shape)))
    assert isinstance(axis, collections.abc.Iterable), (
        'axis needs to be either an iterable, integer or "None"')

    # Materialise once: `axis` is iterated below and then handed to jnp.sum.
    # A one-shot iterable (e.g. a generator) would otherwise be exhausted by
    # the loop and jnp.sum would receive an empty set of axes.
    axis = tuple(axis)

    broadcast_factor = 1.
    for axis_ in axis:
        value_size = value_shape[axis_]
        mask_size = mask_shape[axis_]
        if mask_size == 1:
            broadcast_factor *= value_size
        else:
            assert mask_size == value_size

    # `mask * value` may broadcast the mask along reduced axes, so the plain
    # mask sum undercounts the number of contributing elements; scale it by
    # broadcast_factor. eps guards against an all-zero mask.
    return (jnp.sum(mask * value, axis=axis) /
            (jnp.sum(mask, axis=axis) * broadcast_factor + eps))
def glorot_uniform():
    """Return a Haiku variance-scaling initializer set up as Glorot uniform.

    Returns:
      `hk.initializers.VarianceScaling` with scale 1.0, `fan_avg` mode and a
      uniform distribution.
    """
    settings = dict(scale=1.0, mode='fan_avg', distribution='uniform')
    return hk.initializers.VarianceScaling(**settings)
# Value written into the logits of masked-out (padded) positions before the
# softmax so that they receive (numerically) zero attention weight. The
# attention code below reads this as a module-level name.
_SOFTMAX_MASK = -1e9


class GlobalAttention(hk.Module):
    """Global attention.

    Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7

    A single query per batch element is formed from the masked mean of
    `q_data` and projected into `num_head` heads. Keys and values are
    projected once from `m_data` and shared by all heads.
    """

    def __init__(self, config, global_config, output_dim, name='attention'):
        """Stores the configuration; parameters are created in `__call__`.

        Arguments:
          config: Module config. `__call__` reads the optional keys `key_dim`
            and `value_dim` via `.get`, and the attributes `num_head` and
            `gating`.
          global_config: Global config. `__call__` reads `zero_init`.
          output_dim: Size of the last dimension of the output.
          name: Haiku module name.
        """
        super().__init__(name=name)

        self.config = config
        self.global_config = global_config
        self.output_dim = output_dim

    def __call__(self, q_data, m_data, q_mask):
        """Builds GlobalAttention module.

        Arguments:
          q_data: A tensor of queries with size [batch_size, N_queries,
            q_channels]
          m_data: A tensor of memories from which the keys and values
            projected. Size [batch_size, N_keys, m_channels]
          q_mask: A binary mask for q_data with zeros in the padded sequence
            elements and ones otherwise. Size [batch_size, N_queries,
            q_channels] (or broadcastable to this shape).

        Returns:
          A tensor of size [batch_size, N_queries, output_dim] when
          `config.gating` is set, otherwise [batch_size, 1, output_dim].
          All parameters are created with the dtype of `q_data`.

        NOTE: the mask applied along the key axis of the logits is taken from
        `q_mask`, so N_queries and N_keys must match (or broadcast).
        """
        # Sensible default for when the config keys are missing
        key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
        value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
        num_head = self.config.num_head
        assert key_dim % num_head == 0
        assert value_dim % num_head == 0
        # From here on key_dim / value_dim are per-head sizes.
        key_dim = key_dim // num_head
        value_dim = value_dim // num_head

        q_weights = hk.get_parameter(
            'query_w', shape=(q_data.shape[-1], num_head, key_dim),
            dtype=q_data.dtype,
            init=glorot_uniform())
        k_weights = hk.get_parameter(
            'key_w', shape=(m_data.shape[-1], key_dim),
            dtype=q_data.dtype,
            init=glorot_uniform())
        v_weights = hk.get_parameter(
            'value_w', shape=(m_data.shape[-1], value_dim),
            dtype=q_data.dtype,
            init=glorot_uniform())

        # Values: a single projection without a head axis.
        v = jnp.einsum('bka,ac->bkc', m_data, v_weights)

        # Masked mean of q_data over the query axis (padded entries are
        # excluded), taken before the query projection.
        # q_avg size: [batch_size, q_channels]
        q_avg = mask_mean(q_mask, q_data, axis=1)

        # Queries: project the averaged vector into num_head heads and scale
        # by key_dim**-0.5.
        q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
        # Keys: linear projection of every memory element, a -> c.
        k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
        # Insert a singleton head axis and keep channel 0 of the mask.
        # bias size: [batch_size, 1, N_queries]
        bias = q_mask[:, None, :, 0]
        # Attention scores; unlike per-query attention
        # ('bqhc,bkhc->bhqk') there is no q axis here.
        logits = jnp.einsum('bhc,bkc->bhk', q, k)
        # Where bias is zero (padded positions) the logit is replaced by
        # _SOFTMAX_MASK; elsewhere the logit is kept.
        logits = jnp.where(bias, logits, _SOFTMAX_MASK)
        weights = stable_softmax(logits)
        # One weighted average of the values per head (no per-query axis).
        weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)

        if self.global_config.zero_init:
            init = hk.initializers.Constant(0.0)
        else:
            init = glorot_uniform()

        o_weights = hk.get_parameter(
            'output_w', shape=(num_head, value_dim, self.output_dim),
            dtype=q_data.dtype,
            init=init)
        o_bias = hk.get_parameter(
            'output_b', shape=(self.output_dim,),
            dtype=q_data.dtype,
            init=hk.initializers.Constant(0.0))

        if self.config.gating:
            gating_weights = hk.get_parameter(
                'gating_w',
                shape=(q_data.shape[-1], num_head, value_dim),
                dtype=q_data.dtype,
                init=hk.initializers.Constant(0.0))
            gating_bias = hk.get_parameter(
                'gating_b',
                shape=(num_head, value_dim),
                dtype=q_data.dtype,
                init=hk.initializers.Constant(1.0))
            gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
            gate_values = jax.nn.sigmoid(gate_values + gating_bias)
            # Add a query axis to weighted_avg, then multiply elementwise by
            # the per-query gate values.
            weighted_avg = weighted_avg[:, None] * gate_values
            output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
        else:
            output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias
            # Restore a (singleton) query axis.
            output = output[:, None]
        return output
haiku实现MSA全局注意力模块
最新推荐文章于 2024-06-05 18:19:25 发布