2021SC@SDUSC
前言
接着上一篇,今天继续深挖modules.py,让我们徜徉在晦涩难懂的代码中吧。
1、apply_dropout
def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
if is_training and rate != 0.0:
shape = list(tensor.shape)
if broadcast_dim is not None:
shape[broadcast_dim] = 1
keep_rate = 1.0 - rate
keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape)
return keep * tensor / keep_rate
else:
return tensor
这是上图的代码是在对tensor(张量)应用dropout。要想弄懂这个东西,首先要明确什么是张量和dropout。
(1)张量
经过我的学习,我觉得对张量最简单的描述,可以简单地认为张量就是一个n维的集合或者数组。
0维张量/标量 标量是一个数字
1维张量/向量 1维张量称为“向量”。
2维张量 2维张量称为矩阵
3维张量 公用数据存储在张量 时间序列数据 股价 文本数据 彩色图片(RGB)
(2)dropout
Dropout可以作为训练深度神经网络的一种trick供选择。在每个训练批次中,通过忽略一半的特征检测器(让一半的隐层节点值为0),可以明显地减少过拟合现象。这种方式可以减少特征检测器(隐层节点)间的相互作用,检测器相互作用是指某些检测器依赖其他检测器才能发挥作用。Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更强,因为它不会太依赖某些局部的特征,。
2、class AlphaFold
class AlphaFold(hk.Module):
def __init__(self, config, name='alphafold'):
super().__init__(name=name)
self.config = config
self.global_config = config.global_config
之前就提到过modules.py中class AlphaFoldIteration实现一次model的执行,着这里的class AlphaFold则是实现了recycling。根据论文中所写,一共recycling了三次。
def __call__(
self,
batch,
is_training,
compute_loss=False,
ensemble_representations=False,
return_representations=False):
impl = AlphaFoldIteration(self.config, self.global_config)
batch_size, num_residues = batch['aatype'].shape
def get_prev(ret):
new_prev = {
'prev_pos':
ret['structure_module']['final_atom_positions'],
'prev_msa_first_row': ret['representations']['msa_first_row'],
'prev_pair': ret['representations']['pair'],
}
return jax.tree_map(jax.lax.stop_gradient, new_prev)
def do_call(prev,
recycle_idx,
compute_loss=compute_loss):
if self.config.resample_msa_in_recycling:
num_ensemble = batch_size // (self.config.num_recycle + 1)
def slice_recycle_idx(x):
start = recycle_idx * num_ensemble
size = num_ensemble
return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
else:
num_ensemble = batch_size
ensembled_batch = batch
non_ensembled_batch = jax.tree_map(lambda x: x, prev)
return impl(
ensembled_batch=ensembled_batch,
non_ensembled_batch=non_ensembled_batch,
is_training=is_training,
compute_loss=compute_loss,
ensemble_representations=ensemble_representations)
if self.config.num_recycle:
emb_config = self.config.embeddings_and_evoformer
prev = {
'prev_pos': jnp.zeros(
[num_residues, residue_constants.atom_type_num, 3]),
'prev_msa_first_row': jnp.zeros(
[num_residues, emb_config.msa_channel]),
'prev_pair': jnp.zeros(
[num_residues, num_residues, emb_config.pair_channel]),
}
这是要用来运行一个AlphaFold model的
参数及返回值解释:
参数:
batch:带有 AlphaFold 模型输入的字典。
is_training:系统是否处于训练或推理模式。
compute_loss:是否计算损失(需要在批处理中提供额外的功能,并了解真实的结构)。
ensemble_representations:是否使用表示的集成。
return_representations:是否还返回中间表示。
返回:
当 compute_loss 为 True 时: 返回AlphaFoldIteration 的损失和输出元组。
当 compute_loss 为 False 时:只是 返回AlphaFoldIteration 的输出。
AlphaFoldIteration 的输出是一个嵌套字典,其中包含来自不同头部的预测。
if 'num_iter_recycling' in batch:
num_iter = batch['num_iter_recycling'][0]
上图中每个ensemble batch的值都是一样的,所以无所谓,随便取就行。num_iter = batch[‘num_iter_recycling’][0]这里是取值取了第0个。
num_iter = jnp.minimum(num_iter, self.config.num_recycle)
回收比模型配置运行。
else:
# Eval mode or tests: use the maximum number of iterations.
num_iter = self.config.num_recycle
body = lambda x: (x[0] + 1, # pylint: disable=g-long-lambda
get_prev(do_call(x[1], recycle_idx=x[0],
compute_loss=False)))
评估模式或测试:使用最大迭代次数。
if hk.running_init():
_, prev = body((0, prev))
else:
_, prev = hk.while_loop(
lambda x: x[0] < num_iter,
body,
(0, prev))
初始化 Haiku 模块时,运行一次迭代。
while_loop 用于初始化body
中使用的Haiku 模块。
3、MSA多重对比序列
因为在上述代码中和论文中都提到了MSA,所以我就学习了一下,究竟什么是MSA。
(1)MSA介绍
MSA全称为multiple sequence alignment。多对比序列就是把两个以上序列对齐,逐列比较其字符的异同,使得每一列的字符尽可能一致,以发现其共同的结构特征。MSA的目标是使得参与比对的序列有尽可能多的列具有相同的字符,使得相同残基的位点位于同一列,以便发现不同序列之间的相似部分,从而推测他们结构和功能上面的相似关系。
(2)MSA的意义
(i).系统发育分析
用于描述同源序列之间的亲缘关系的远近,应用到分子演化分析中。是构建分子演化树的基础。
(ii).功能分析
用于描述一组序列之间的相似关系,以便了解一个基因家族的基本特征,寻找motif、保守区域等。用于预测新序列的二级和三级结构,进而推测其生物学功能。
(iii).突变分析
用于揭示不同个体的基因组由于突变而产生的差异。
不同物种基因组范围的MSA能分析基因组结构变异和共线性。
(iv).测序分析
用于获得共性序列;用于序列拼接。
总结
了解过MSA后,下周结合论文就能看有关MSA部分的代码了