2021SC@SDUSC
1.class TriangleMultiplication源码
class TriangleMultiplication(hk.Module):
def __init__(self, config, global_config, name='triangle_multiplication'):
super().__init__(name=name)
self.config = config
self.global_config = global_config
def __call__(self, act, mask, is_training=True):
del is_training
c = self.config
gc = self.global_config
mask = mask[..., None]
act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(act)
input_act = act
left_projection = common_modules.Linear(
c.num_intermediate_channel,
name='left_projection')
left_proj_act = mask * left_projection(act)
right_projection = common_modules.Linear(
c.num_intermediate_channel,
name='right_projection')
right_proj_act = mask * right_projection(act)
left_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='left_gate')(act))
right_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='right_gate')(act))
left_proj_act *= left_gate_values
right_proj_act *= right_gate_values
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='center_layer_norm')(
act)
output_channel = int(input_act.shape[-1])
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
gate_values = jax.nn.sigmoid(common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(input_act))
act *= gate_values
return act
这个类描述的是Triangle multiplication layer的设置
其中 def call(self, act, mask, is_training=True):函数是builds TriangleMultiplication modulede的作用
这个函数中的参数分别是:
act: Pair activations, shape [N_res, N_res, c_z]
mask: Pair mask, shape [N_res, N_res].
is_training: Whether the module is in training mode.
2.流程图和伪代码
三角形的乘法更新如下图,通过结合图边ij、ik和jk的每个三角形中的信息来更新Evoformer block中的对表示。每条边i所j从有三角形的其他两条边接收一个更新,其中涉及到它。
上图对应的伪代码如下图
将上面两张图结合起来看,流程就很清晰了:
zij先经过LayerNorm。然后zij经过Linear又进过sigmoid分别和zij直接Linear后的结果进行乘法操作,分别得到了aij和bij,对应于流程图中标记好的1和2位置的结果。
gij是zij经过Linear后又经过sigmoid后的结果,对应于流程图中的3的位置上。
将aij和bij进过乘法操作后再求和,再经过LayerNorm和Linear后和上面gij进行乘法操作后就得到了zij(hat),也就对用于流程图中的4的位置
3.结合流程图和伪代码的源码分析
act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
name='layer_norm_input')(act)
这一句对应的是伪代码中的第一行,先经过LayerNorm的处理后得到新的zij
left_projection = common_modules.Linear(
c.num_intermediate_channel,
name='left_projection')
left_proj_act = mask * left_projection(act)
right_projection = common_modules.Linear(
c.num_intermediate_channel,
name='right_projection')
right_proj_act = mask * right_projection(act)
上述代码对应的是流程图中最下面的两条线,zij进过Linear后分别得到left edges和right edges
left_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='left_gate')(act))
right_gate_values = jax.nn.sigmoid(common_modules.Linear(
c.num_intermediate_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='right_gate')(act))
上述的这两段代码对应着流程图中的最上面分支的前两条线,zij进过Linear后又经过sigmoid后的结果
left_proj_act *= left_gate_values
right_proj_act *= right_gate_values
上述代码是对应着流程图中最前面的两个乘法操作,这两个乘法操作做完后就得到了aij和bij,对应着流程图中1和2的位置的结果,对应着伪代码中第二句的意思
这里要注意对于outgoing 边的话,流程图中的a对应着代码中的 left_proj_act ,流程图中的b对应着 right_proj_act
但是对于incoming边,则是反过来,流程图中的a对应着代码中的 right_proj_act ,流程图中的b对应着 left_proj_act
act = jnp.einsum(c.equation, left_proj_act, right_proj_act)
这代码对应着流程图中求导sum r的那一步
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset=True,
name='center_layer_norm')(
act)
这里是将得到sum后的结果又输入进LayerNorm中,得到新的act
act = common_modules.Linear(
output_channel,
initializer=utils.final_init(gc),
name='output_projection')(act)
上述代码试讲刚刚经历过LayerNorm后的act再进行Linear处理
gate_values = jax.nn.sigmoid(common_modules.Linear(
output_channel,
bias_init=1.,
initializer=utils.final_init(gc),
name='gating_linear')(input_act))
这是在执行流程图中上分支线中的第三条,先经过Linear后再进过sigmoid,得到gate_values
act *= gate_values
return act
这是在执行流程图中的最后一步乘法,将前面已有的act和gate_values相乘,然后得到流程图汇总4位置的结果,返回给act