2021SC@SDUSC
前言:这一篇让我们看看folding.py
InvariantPointAttention
class InvariantPointAttention(hk.Module):
def __init__(self,
             config,
             global_config,
             dist_epsilon=1e-8,
             name='invariant_point_attention'):
    """Initialize the invariant point attention module.

    Args:
      config: structure-module configuration (read later for head/channel
        counts in ``__call__``).
      global_config: global model configuration; ``zero_init`` controls
        whether the last layer is zero-initialized.
      dist_epsilon: small constant added inside distance computations to
        avoid NaNs (guards sqrt of zero).
      name: Haiku module name.
    """
    super().__init__(name=name)
    # Keep both configs; the forward pass reads its hyperparameters from them.
    self.config = config
    self.global_config = global_config
    # Whether the final output projection should start at zero.
    self._zero_initialize_last = global_config.zero_init
    # Numerical-stability epsilon for point-distance features.
    self._dist_epsilon = dist_epsilon
这个类负责构建不变点注意力(Invariant Point Attention)模块。高层次的想法是:这个注意力模块在一组点以及 3D 空间中的相关方向(例如蛋白质残基)上工作。每个残基在其局部参考系中输出一组查询点和键点,然后将注意力定义为查询和键在全局框架中的欧几里得距离。
上述代码段就是在完成初始化,初始化函数的参数解释如下:
config:结构模块配置
global_config:模型的全局配置。
dist_epsilon:小值以避免在距离计算中出现 NaN。
name:Haiku 模块名称。
def __call__(self, inputs_1d, inputs_2d, mask, affine):
num_residues, _ = inputs_1d.shape
这个函数的作用是给定一组查询残基(由仿射和相关标量定义)特征,这个函数计算几何感知注意力,查询残基和目标残基。残差在其局部参考系中产生点,被转换成全局框架,以便通过计算注意力欧氏距离。等效地,目标残基在它们的局部框架中产生点是用作注意力值,将其转换为查询残差本地帧。鉴于前面已经介绍过欧氏距离,这里就不过多赘述了。
这个函数的参数介绍如下:
inputs_1d:(N,C)一维输入嵌入,这是标量查询。
inputs_2d: (N, M, C’) 2D 输入嵌入,用于偏差和值。
mask:(N, 1) 的掩码,指示 inputs_1d 中哪些元素参与注意力计算。
affine:描述位置和方向的 QuatAffine 对象,是input_1d 中的每个元素。
num_head = self.config.num_head
num_scalar_qk = self.config.num_scalar_qk
num_point_qk = self.config.num_point_qk
num_scalar_v = self.config.num_scalar_v
num_point_v = self.config.num_point_v
num_output = self.config.num_channel
上述代码把配置项赋值给局部变量,省去后面大量的 'self.config.' 前缀,从而提高可读性。
assert num_scalar_qk > 0
assert num_point_qk > 0
assert num_point_v > 0
这里是一些断言语句,确保标量 QK、点 QK 和点 V 的通道数均为正。
q_scalar = common_modules.Linear(
num_head * num_scalar_qk, name='q_scalar')(
inputs_1d)
q_scalar = jnp.reshape(
q_scalar, [num_residues, num_head, num_scalar_qk])
这是在构造形状为 [num_residues, num_head, num_scalar_qk] 的标量查询。
kv_scalar = common_modules.Linear(
num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
inputs_1d)
kv_scalar = jnp.reshape(kv_scalar,
[num_residues, num_head,
num_scalar_v + num_scalar_qk])
k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)
这是在构造标量键/值:先得到形状为 [num_residues, num_head, num_scalar_qk + num_scalar_v] 的张量,再沿最后一维拆分为键(num_scalar_qk)和值(num_scalar_v)。
q_point_local = common_modules.Linear(
num_head * 3 * num_point_qk, name='q_point_local')(
inputs_1d)
q_point_local = jnp.split(q_point_local, 3, axis=-1)
q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
q_point = [
jnp.reshape(x, [num_residues, num_head, num_point_qk])
for x in q_point_global]
这是在构造形状是[num_residues, num_head, num_point_qk]的查询点
q_point_local = common_modules.Linear(
num_head * 3 * num_point_qk, name='q_point_local')(
inputs_1d)
q_point_local = jnp.split(q_point_local, 3, axis=-1)
其中上述代码段首先在局部框架中构造查询点
q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
其次这是将查询点投影(project)到全局框架。
q_point = [
jnp.reshape(x, [num_residues, num_head, num_point_qk])
for x in q_point_global]
最后这是重塑查询点以备后用
kv_point_local = common_modules.Linear(
num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
inputs_1d)
kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
kv_point_global = [
jnp.reshape(x, [num_residues,
num_head, (num_point_qk + num_point_v)])
for x in kv_point_global]
k_point, v_point = list(
zip(*[
jnp.split(x, [num_point_qk,], axis=-1)
for x in kv_point_global
]))
这一大段代码是要构造键值点,其中Key points的形状 为[num_residues, num_head, num_point_qk],
Value points形状为 [num_residues, num_head, num_point_v]
scalar_variance = max(num_scalar_qk, 1) * 1.
这是在假设所有查询和键都来自 N(0, 1) 分布,计算注意力对数的方差。每个标量对 (q, k) 贡献 Var q*k = 1
point_variance = max(num_point_qk, 1) * 9. / 2
每个点对 (q, k) 贡献 Var[0.5 ||q||^2 - &lt;q, k&gt;] = 9/2。
num_logit_terms = 3
scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))
上述代码段是为标量、点和注意力二维部分分配相等的方差,以便总和为 1。
trainable_point_weights = jax.nn.softplus(hk.get_parameter(
'trainable_point_weights', shape=[num_head],
# softplus^{-1} (1)
init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)
v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]
q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
dist2 = [
squared_difference(qx[:, :, None, :], kx[:, None, :, :])
for qx, kx in zip(q_point, k_point)
]
dist2 = sum(dist2)
attn_qk_point = -0.5 * jnp.sum(
point_weights[:, None, None, :] * dist2, axis=-1)
v = jnp.swapaxes(v_scalar, -2, -3)
q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
k = jnp.swapaxes(k_scalar, -2, -3)
attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
attn_logits = attn_qk_scalar + attn_qk_point
attention_2d = common_modules.Linear(
num_head, name='attention_2d')(
inputs_2d)
attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
attention_2d = attention_2d_weights * attention_2d
attn_logits += attention_2d
mask_2d = mask * jnp.swapaxes(mask, -1, -2)
attn_logits -= 1e5 * (1. - mask_2d)
attn = jax.nn.softmax(attn_logits)
result_scalar = jnp.matmul(attn, v)
这一大段代码先用可训练的逐头点权重计算点距离注意力项,再计算标量注意力项,叠加来自 inputs_2d 的偏置并应用掩码,最后做 softmax 得到注意力权重,并用它聚合标量值。
result_point_global = [jnp.sum(
attn[:, :, :, None] * vx[:, None, :, :],
axis=-2) for vx in v_point]
result_scalar = jnp.swapaxes(result_scalar, -2, -3)
result_point_global = [
jnp.swapaxes(x, -2, -3)
for x in result_point_global]
output_features = []
result_scalar = jnp.reshape(
result_scalar, [num_residues, num_head * num_scalar_v])
output_features.append(result_scalar)
result_point_global = [
jnp.reshape(r, [num_residues, num_head * num_point_v])
for r in result_point_global]
result_point_local = affine.invert_point(result_point_global, extra_dims=1)
output_features.extend(result_point_local)
output_features.append(jnp.sqrt(self._dist_epsilon +
jnp.square(result_point_local[0]) +
jnp.square(result_point_local[1]) +
jnp.square(result_point_local[2])))
这一大段是对点结果手动实现 matmul(乘法加逐项求和),以便在 TPU 上以 float32 精度计算。这相当于 result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx) for vx in v_point]。
但在 TPU 上,执行乘法和缩减总和可确保计算发生在 float32 而不是 bfloat16。
其中 output_features = []是储存线性输出投影中使用的特征,这点从后面不断地append应该可以看出来
result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
num_out = num_head * result_attention_over_2d.shape[-1]
output_features.append(
jnp.reshape(result_attention_over_2d,
[num_residues, num_out]))
final_init = 'zeros' if self._zero_initialize_last else 'linear'
final_act = jnp.concatenate(output_features, axis=-1)
收缩发生在第二个残基维度上,类似于执行通常的注意力。