AlphaFold2代码阅读(七)

2021SC@SDUSC


1.class TriangleMultiplication源码

class TriangleMultiplication(hk.Module):

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
  
    del is_training
    c = self.config
    gc = self.global_config

    mask = mask[..., None]

    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                       name='layer_norm_input')(act)
    input_act = act

    left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='right_projection')
    right_proj_act = mask * right_projection(act)

    left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

   
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(
            act)

    output_channel = int(input_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
    act *= gate_values

    return act

这个类描述的是Triangle multiplication layer的设置
其中 def call(self, act, mask, is_training=True):函数是builds TriangleMultiplication modulede的作用
这个函数中的参数分别是:
act: Pair activations, shape [N_res, N_res, c_z]
mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.

2.流程图和伪代码

三角形的乘法更新如下图,通过结合图边ij、ik和jk的每个三角形中的信息来更新Evoformer block中的对表示。每条边i所j从有三角形的其他两条边接收一个更新,其中涉及到它。
在这里插入图片描述
上图对应的伪代码如下图在这里插入图片描述
将上面两张图结合起来看,流程就很清晰了:
zij先经过LayerNorm。然后zij经过Linear又进过sigmoid分别和zij直接Linear后的结果进行乘法操作,分别得到了aij和bij,对应于流程图中标记好的1和2位置的结果。
gij是zij经过Linear后又经过sigmoid后的结果,对应于流程图中的3的位置上。
将aij和bij进过乘法操作后再求和,再经过LayerNorm和Linear后和上面gij进行乘法操作后就得到了zij(hat),也就对用于流程图中的4的位置

3.结合流程图和伪代码的源码分析

    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                       name='layer_norm_input')(act)

这一句对应的是伪代码中的第一行,先经过LayerNorm的处理后得到新的zij

left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='right_projection')
    right_proj_act = mask * right_projection(act)

上述代码对应的是流程图中最下面的两条线,zij进过Linear后分别得到left edges和right edges

left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

上述的这两段代码对应着流程图中的最上面分支的前两条线,zij进过Linear后又经过sigmoid后的结果

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

上述代码是对应着流程图中最前面的两个乘法操作,这两个乘法操作做完后就得到了aij和bij,对应着流程图中1和2的位置的结果,对应着伪代码中第二句的意思

这里要注意对于outgoing 边的话,流程图中的a对应着代码中的 left_proj_act ,流程图中的b对应着 right_proj_act
但是对于incoming边,则是反过来,流程图中的a对应着代码中的 right_proj_act ,流程图中的b对应着 left_proj_act

  act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

这代码对应着流程图中求导sum r的那一步

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(
            act)

这里是将得到sum后的结果又输入进LayerNorm中,得到新的act

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

上述代码试讲刚刚经历过LayerNorm后的act再进行Linear处理

    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
   

这是在执行流程图中上分支线中的第三条,先经过Linear后再进过sigmoid,得到gate_values

    act *= gate_values

    return act

这是在执行流程图中的最后一步乘法,将前面已有的act和gate_values相乘,然后得到流程图汇总4位置的结果,返回给act

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值