AlphaFold2代码阅读(十四)

2021SC@SDUSC


前言:这一篇让我们看看folding.py


InvariantPointAttention

class InvariantPointAttention(hk.Module):
  def __init__(self,
               config,
               global_config,
               dist_epsilon=1e-8,
               name='invariant_point_attention'):
 
    super().__init__(name=name)

    self._dist_epsilon = dist_epsilon
    self._zero_initialize_last = global_config.zero_init

    self.config = config

    self.global_config = global_config
               

  这个类是是负责构建不变点注意模块的。高层次的想法是这个注意力模块在一组点上工作以及 3D 空间中的相关方向(例如蛋白质残基)。每个残基输出一组查询和键作为它们本地的点的参考范围。然后将注意力定义为欧几里得距离在全局框架中的查询和键之间。

上述代码段就是在完成初始化,初始化函数的参数解释如下:

config:结构模块配置
global_config:模型的全局配置。
dist_epsilon:小值以避免在距离计算中出现 NaN。
name:Haiku 模块。

  def __call__(self, inputs_1d, inputs_2d, mask, affine):
   
    num_residues, _ = inputs_1d.shape

  这个函数的作用是给定一组查询残基(由仿射和相关标量定义)特征,这个函数计算几何感知注意力,查询残基和目标残基。残差在其局部参考系中产生点,被转换成全局框架,以便通过计算注意力欧氏距离。等效地,目标残基在它们的局部框架中产生点是用作注意力值,将其转换为查询残差本地帧。鉴于前面已经介绍过欧氏距离,这里就不过多赘述了。

这个函数的参数介绍如下:

inputs_1d:(N,C)一维输入嵌入,这是标量查询。
inputs_2d: (N, M, C’) 2D 输入嵌入,用于偏差和值。
mask:(N, 1)mask以指示 input_1d 的哪些元素参与注意。
affine:描述位置和方向的 QuatAffine 对象,是input_1d 中的每个元素。

    num_head = self.config.num_head
    num_scalar_qk = self.config.num_scalar_qk
    num_point_qk = self.config.num_point_qk
    num_scalar_v = self.config.num_scalar_v
    num_point_v = self.config.num_point_v
    num_output = self.config.num_channel

上述代码是通过删除大量的 ‘self’ 来提高可读性。

    assert num_scalar_qk > 0
    assert num_point_qk > 0
    assert num_point_v > 0

这里就是一些断言函数

    q_scalar = common_modules.Linear(
        num_head * num_scalar_qk, name='q_scalar')(
            inputs_1d)
    q_scalar = jnp.reshape(
        q_scalar, [num_residues, num_head, num_scalar_qk])

这是在构造形状为[num_query_residues, num_head, num_points]的标量查询

    kv_scalar = common_modules.Linear(
        num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
            inputs_1d)
    kv_scalar = jnp.reshape(kv_scalar,
                            [num_residues, num_head,
                             num_scalar_v + num_scalar_qk])
    k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)

这是在Construct scalar keys/values of shape:[num_target_residues, num_head, num_points]

    q_point_local = common_modules.Linear(
        num_head * 3 * num_point_qk, name='q_point_local')(
            inputs_1d)
    q_point_local = jnp.split(q_point_local, 3, axis=-1)
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    q_point = [
        jnp.reshape(x, [num_residues, num_head, num_point_qk])
        for x in q_point_global]

这是在构造形状是[num_residues, num_head, num_point_qk]的查询点

    q_point_local = common_modules.Linear(
        num_head * 3 * num_point_qk, name='q_point_local')(
            inputs_1d)
    q_point_local = jnp.split(q_point_local, 3, axis=-1)

其中上述代码段首先在局部框架中构造查询点

    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)

其次这是项目查询指向全局框架

    q_point = [
        jnp.reshape(x, [num_residues, num_head, num_point_qk])
        for x in q_point_global]

最后这是重塑查询点以备后用

    kv_point_local = common_modules.Linear(
        num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
            inputs_1d)
    kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
  
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [
        jnp.reshape(x, [num_residues,
                        num_head, (num_point_qk + num_point_v)])
        for x in kv_point_global]

    k_point, v_point = list(
        zip(*[
            jnp.split(x, [num_point_qk,], axis=-1)
            for x in kv_point_global
        ]))

这一大段代码是要构造键值点,其中Key points的形状 为[num_residues, num_head, num_point_qk],
Value points形状为 [num_residues, num_head, num_point_v]

    scalar_variance = max(num_scalar_qk, 1) * 1.

这是在假设所有查询和键都来自 N(0, 1) 分布,计算注意力对数的方差。每个标量对 (q, k) 贡献 Var q*k = 1

    point_variance = max(num_point_qk, 1) * 9. / 2

Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2

    num_logit_terms = 3

    scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
    point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
    attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

上述代码段是为标量、点和注意力二维部分分配相等的方差,以便总和为 1。

    trainable_point_weights = jax.nn.softplus(hk.get_parameter(
        'trainable_point_weights', shape=[num_head],
        # softplus^{-1} (1)
        init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
    point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)

    v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]

    q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
    k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
    dist2 = [
        squared_difference(qx[:, :, None, :], kx[:, None, :, :])
        for qx, kx in zip(q_point, k_point)
    ]
    dist2 = sum(dist2)
    attn_qk_point = -0.5 * jnp.sum(
        point_weights[:, None, None, :] * dist2, axis=-1)

    v = jnp.swapaxes(v_scalar, -2, -3)
    q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
    k = jnp.swapaxes(k_scalar, -2, -3)
    attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    attn_logits = attn_qk_scalar + attn_qk_point

    attention_2d = common_modules.Linear(
        num_head, name='attention_2d')(
            inputs_2d)

    attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
    attention_2d = attention_2d_weights * attention_2d
    attn_logits += attention_2d

    mask_2d = mask * jnp.swapaxes(mask, -1, -2)
    attn_logits -= 1e5 * (1. - mask_2d)
        attn = jax.nn.softmax(attn_logits)
    result_scalar = jnp.matmul(attn, v)

这一大段代码说明了可训练的点的每头权重

   result_point_global = [jnp.sum(
        attn[:, :, :, None] * vx[:, None, :, :],
        axis=-2) for vx in v_point]

    result_scalar = jnp.swapaxes(result_scalar, -2, -3)
    result_point_global = [
        jnp.swapaxes(x, -2, -3)
        for x in result_point_global]

    output_features = []

    result_scalar = jnp.reshape(
        result_scalar, [num_residues, num_head * num_scalar_v])
    output_features.append(result_scalar)

    result_point_global = [
        jnp.reshape(r, [num_residues, num_head * num_point_v])
        for r in result_point_global]
    result_point_local = affine.invert_point(result_point_global, extra_dims=1)
    output_features.extend(result_point_local)

    output_features.append(jnp.sqrt(self._dist_epsilon +
                                    jnp.square(result_point_local[0]) +
                                    jnp.square(result_point_local[1]) +
                                    jnp.square(result_point_local[2])))

这一大段是在对于点结果,手动实现matmul,使其成为一个float32在TPU上。这相当于result_point_global = [jnp.einsum(‘bhqk,bhkc->bhqc’, attn, vx) for vx in v_point]
但在 TPU 上,执行乘法和缩减总和可确保计算发生在 float32 而不是 bfloat16。
其中 output_features = []是储存线性输出投影中使用的特征,这点从后面不断地append应该可以看出来

    result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
    num_out = num_head * result_attention_over_2d.shape[-1]
    output_features.append(
        jnp.reshape(result_attention_over_2d,
                    [num_residues, num_out]))

    final_init = 'zeros' if self._zero_initialize_last else 'linear'

    final_act = jnp.concatenate(output_features, axis=-1)

收缩发生在第二个残差维度上,类似于执行通常的注意力。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值