AlphaFold2代码阅读(二)

2021SC@SDUSC


前言

接着上一篇,今天继续深挖modules.py,让我们徜徉在晦涩难懂的代码中吧。


1、apply_dropout

def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
  if is_training and rate != 0.0:
    shape = list(tensor.shape)
    if broadcast_dim is not None:
      shape[broadcast_dim] = 1
    keep_rate = 1.0 - rate
    keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape)
    return keep * tensor / keep_rate
  else:
    return tensor

这是上图的代码是在对tensor(张量)应用dropout。要想弄懂这个东西,首先要明确什么是张量和dropout。
(1)张量
经过我的学习,我觉得对张量最简单的描述,可以简单地认为张量就是一个n维的集合或者数组。
0维张量/标量 标量是一个数字

1维张量/向量 1维张量称为“向量”。

2维张量 2维张量称为矩阵

3维张量 公用数据存储在张量 时间序列数据 股价 文本数据 彩色图片(RGB)

(2)dropout
Dropout可以作为训练深度神经网络的一种trick供选择。在每个训练批次中,通过忽略一半的特征检测器(让一半的隐层节点值为0),可以明显地减少过拟合现象。这种方式可以减少特征检测器(隐层节点)间的相互作用,检测器相互作用是指某些检测器依赖其他检测器才能发挥作用。Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更强,因为它不会太依赖某些局部的特征,。

2、class AlphaFold

在这里插入图片描述

class AlphaFold(hk.Module):

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config
    

  之前就提到过modules.py中class AlphaFoldIteration实现一次model的执行,着这里的class AlphaFold则是实现了recycling。根据论文中所写,一共recycling了三次。

 def __call__(
      self,
      batch,
      is_training,
      compute_loss=False,
      ensemble_representations=False,
      return_representations=False):

    impl = AlphaFoldIteration(self.config, self.global_config)
    batch_size, num_residues = batch['aatype'].shape

    def get_prev(ret):
      new_prev = {
          'prev_pos':
              ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    def do_call(prev,
                recycle_idx,
                compute_loss=compute_loss):
      if self.config.resample_msa_in_recycling:
        num_ensemble = batch_size // (self.config.num_recycle + 1)
        def slice_recycle_idx(x):
          start = recycle_idx * num_ensemble
          size = num_ensemble
          return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
        ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
      else:
        num_ensemble = batch_size
        ensembled_batch = batch

      non_ensembled_batch = jax.tree_map(lambda x: x, prev)

      return impl(
          ensembled_batch=ensembled_batch,
          non_ensembled_batch=non_ensembled_batch,
          is_training=is_training,
          compute_loss=compute_loss,
          ensemble_representations=ensemble_representations)

    if self.config.num_recycle:
      emb_config = self.config.embeddings_and_evoformer
      prev = {
          'prev_pos': jnp.zeros(
              [num_residues, residue_constants.atom_type_num, 3]),
          'prev_msa_first_row': jnp.zeros(
              [num_residues, emb_config.msa_channel]),
          'prev_pair': jnp.zeros(
              [num_residues, num_residues, emb_config.pair_channel]),
      }



这是要用来运行一个AlphaFold model的
参数及返回值解释:
参数:
batch:带有 AlphaFold 模型输入的字典。
is_training:系统是否处于训练或推理模式。
compute_loss:是否计算损失(需要在批处理中提供额外的功能,并了解真实的结构)。
ensemble_representations:是否使用表示的集成。
return_representations:是否还返回中间表示。
返回:
当 compute_loss 为 True 时: 返回AlphaFoldIteration 的损失和输出元组。
当 compute_loss 为 False 时:只是 返回AlphaFoldIteration 的输出。
AlphaFoldIteration 的输出是一个嵌套字典,其中包含来自不同头部的预测。

      if 'num_iter_recycling' in batch:
        num_iter = batch['num_iter_recycling'][0]

上图中每个ensemble batch的值都是一样的,所以无所谓,随便取就行。num_iter = batch[‘num_iter_recycling’][0]这里是取值取了第0个。

num_iter = jnp.minimum(num_iter, self.config.num_recycle)

回收比模型配置运行。

      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = self.config.num_recycle

      body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
                        get_prev(do_call(x[1], recycle_idx=x[0],
                                         compute_loss=False)))

评估模式或测试:使用最大迭代次数。

      if hk.running_init():
        _, prev = body((0, prev))
      else:
        _, prev = hk.while_loop(
            lambda x: x[0] < num_iter,
            body,
            (0, prev))

初始化 Haiku 模块时,运行一次迭代。
while_loop 用于初始化body 中使用的Haiku 模块。

3、MSA多重对比序列

因为在上述代码中和论文中都提到了MSA,所以我就学习了一下,究竟什么是MSA。

(1)MSA介绍
MSA全称为multiple sequence alignment。多对比序列就是把两个以上序列对齐,逐列比较其字符的异同,使得每一列的字符尽可能一致,以发现其共同的结构特征。MSA的目标是使得参与比对的序列有尽可能多的列具有相同的字符,使得相同残基的位点位于同一列,以便发现不同序列之间的相似部分,从而推测他们结构和功能上面的相似关系。

(2)MSA的意义

(i).系统发育分析
用于描述同源序列之间的亲缘关系的远近,应用到分子演化分析中。是构建分子演化树的基础。

(ii).功能分析
用于描述一组序列之间的相似关系,以便了解一个基因家族的基本特征,寻找motif、保守区域等。用于预测新序列的二级和三级结构,进而推测其生物学功能。

(iii).突变分析
用于揭示不同个体的基因组由于突变而产生的差异。
不同物种基因组范围的MSA能分析基因组结构变异和共线性。

(iv).测序分析
用于获得共性序列;用于序列拼接。


总结

了解过MSA后,下周结合论文就能看有关MSA部分的代码了

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值