TorchProtein教程--预训练的蛋白质结构表示(5)

TorchProtein教程–预训练的蛋白质结构表示(5)

本教程来自唐建团队的开源框架torchprotein

目录

在许多基于结构的蛋白质理解任务中,收集标记数据的时间和金钱都很昂贵。作为一种解决方案,提出了自监督预训练策略,以从大量未标记的蛋白质结构中获得信息丰富的蛋白质表示。在本教程中,我们将介绍如何预训练基于结构的蛋白质编码器,然后在下游任务中对其进行微调。

蛋白质结构数据表示

在本部分中,我们将学习如何获取基于结构的数据集进行预训练,并进一步用额外的边增强每个样本,以更好地表示其结构。

蛋白质结构数据集

让我们首先构建一个蛋白质结构数据集。为了提高效率,我们定义了一个基于datasets. enzymatic ommissiontoyEnzymeCommissionToy蛋白质结构数据集。此外,我们向数据集传递两个转换函数来截断过长的蛋白质并指定节点特征。

from torchdrug import datasets, transforms

# A toy protein structure dataset
class EnzymeCommissionToy(datasets.EnzymeCommission):
    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/data/EnzymeCommission.tar.gz"
    md5 = "728e0625d1eb513fa9b7626e4d3bcf4d"
    processed_file = "enzyme_commission_toy.pkl.gz"
    test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95]

truncuate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view='residue')
transform = transforms.Compose([truncuate_transform, protein_view_transform])

dataset = EnzymeCommissionToy("~/protein-datasets/", transform=transform, atom_feature=None, bond_feature=None)
train_set, valid_set, test_set = dataset.split()
print(dataset)
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))

动态图构造

由RDKit构建的蛋白质数据只包含四种类型的键边(即single, double, triple或aromatic)。以数据集的第一个样本为例,我们挑选出前两个残基的原子,并显示它们之间的化学键。

from torchdrug import data

protein = dataset[0]["graph"]
is_first_two = (protein.residue_number == 1) | (protein.residue_number == 2)
first_two = protein.residue_mask(is_first_two, compact=True)
first_two.visualize()

为了更好地表示蛋白质结构,我们试图通过layers.GraphConstruction模块动态重建蛋白质图。对于节点,我们使用layers.geometry.AlphaCarbonNode从蛋白质中提取Alpha碳,以构建残差级图。对于节点,我们使用layers.geometry.AlphaCarbonNode从蛋白质中提取Alpha碳,以构建残差级图。对于边,我们使用layers.geometry.SpatialEdge, layers.geometry.KNNEdgelayers.geometry.SequentialEdge来构造不同残差之间的空间、KNN和顺序边(关于这些边的详细定义,请参阅教程3:基于结构的蛋白质属性预测)

from torchdrug import layers
from torchdrug.layers import geometry

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
print("Graph before: ", _protein)
print("Graph after: ", protein_)

degree = protein_.degree_in + protein_.degree_out
print("Average degree: ", degree.mean())
print("Maximum degree: ", degree.max())
print("Minimum degree: ", degree.min())

经过这样的图构建,我们将蛋白质结构表示为残差级关系图。通过将空间边和KNN边作为两种类型的边,将5种不同顺序距离(即-2,-1,0,1和2)的顺序边作为5种类型的边,我们得到了具有7种不同边缘类型的关系图。每条边与一个59维边缘特征相关联,该特征是其两个端点节点的单热剩余特征、边缘类型、顺序距离和空间距离的拼接。

nodes_in, nodes_out, edges_type = protein_.edge_list.t()
residue_ids = protein_.residue_type.tolist()
for node_in, node_out, edge_type, edge_feature in zip(nodes_in.tolist()[:5], nodes_out.tolist()[:5], edges_type.tolist()[:5], protein_.edge_feature[:5]):
    print("[%s -> %s, type %d] edge feature shape: " % (data.Protein.id2residue[residue_ids[node_in]], 
                                                        data.Protein.id2residue[residue_ids[node_out]], edge_type), edge_feature.shape)

蛋白质结构表示模型

TorchProtein定义了多种GNN模型,可作为蛋白质结构编码器。在本教程中,我们采用了具有边缘消息传递的高级几何感知关系图神经网络(GearNet-Edge)。在TorchProtein中,我们可以用models.GearNet定义一个GearNet-Edge模型。

from torchdrug import models

gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512], 
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

自我监督蛋白结构预训练

在本教程中,我们采用了两种预训练算法,多视图对比学习和残留类型预测,从未标记的蛋白质结构中学习蛋白质表示。

多视图对比学习

多视图对比学习旨在最大化同一蛋白质的不同视图表示之间的相似性,同时最小化不同蛋白质之间的相似性。下图说明了多视图对比学习的高级思想。

我们首先将GearNet-Edge模型包装到模型中。在MultiviewContrast模块中,我们将增强函数传递给aug_funcs参数使用,并将裁剪函数传递给crop_funcs参数使用。这个模块在GearNet-Edge上追加了一个MLP预测头。在此基础上,将Multiview Contrast模块与图形构建模型打包到tasks.Unsupervised模块中进行自我监督预训练。

这里我们使用两种不同的裁剪函数:子空间和子序列。前者随机取一个长度不超过50的较短的连续子序列,而后者取一个球内的所有残基,其中心残基是随机选择的。在裁剪蛋白质后,我们随机选择是否在残基图中随机遮盖边缘作为一种增强。

from torchdrug import layers, models, tasks
from torchdrug.layers import geometry

model = models.MultiviewContrast(gearnet_edge, noise_funcs=[geometry.IdentityNode(), geometry.RandomEdgeMask(mask_rate=0.15)],
                                 crop_funcs=[geometry.SubsequenceNode(max_length=50), 
                                             geometry.SubspaceNode(entity_level="residue", min_neighbor=15, min_radius=15.0)], num_mlp_layer=2)
task = tasks.Unsupervised(model, graph_construction_model=graph_construction_model)

现在我们可以开始预训练了。我们为我们的模型设置了一个优化器,并将所有内容放在一个Engine实例中。在这个预训练任务上训练模型10个epoch大约需要5分钟。最后保存上一个epoch的模型权值。

from torchdrug import core

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=4)
solver.train(num_epoch=10)
solver.save("MultiviewContrast_ECToy.pth")

残基类型预测

残基类型预测是一种典型的自预测任务,它将输入残基级图中的一部分残基进行掩模,并根据蛋白质的上下文规律对被掩模的残基类型进行预测。这种方法也被称为蛋白质的掩模反折叠(预测给定结构的序列)。下图说明了残留物类型预测的高级思想。

为了完成这一任务,我们将GearNet-Edge模型和图形构造模型包装到tasks.AttributeMasking模块中,其中一个MLP预测头将附加在GearNet-Edge上。注意,这个模块也可以用来预训练分子编码器。模块将根据训练集中的图视图选择是预测原子类型还是残基类型。

task = tasks.AttributeMasking(gearnet_edge, graph_construction_model=graph_construction_model,
                              mask_rate=0.15, num_mlp_layer=2)

现在我们可以开始预训练了。与上面类似,我们为我们的模型设置了一个优化器,并将所有内容放在一个Engine实例中。在这个预训练任务上训练模型10个epoch大约需要8分钟。最后保存上一个epoch的模型权值。

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=4)
solver.train(num_epoch=10)
solver.save("ResidueTypePrediction_ECToy.pth")

下游任务的微调

我们使用玩具酶委员会数据集上的蛋白质功能项预测作为下游任务。这个任务的目的是预测一个蛋白质是否拥有几个特定的功能,其中每个功能的拥有可以用一个二进制标签来表示。因此,我们将该任务制定为多个二元分类任务,并寻求以多任务学习的方式共同解决它们。我们使用tasks.MultipleBinaryClassification模块来执行这项任务,该模块将GearNet-Edge模型与MLP预测头相结合。

task = tasks.MultipleBinaryClassification(gearnet_edge, graph_construction_model=graph_construction_model, num_mlp_layer=3,
                                          task=[_ for _ in range(len(dataset.tasks))], criterion="bce", metric=['auprc@micro', 'f1_max'])

从头开始训练

我们首先通过从头开始的训练来评估GearNet-Edge。在这个任务上训练模型10个epoch大约需要8分钟。最后在验证集上求值。

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=4)
solver.train(num_epoch=10)
solver.evaluate("valid")

微调多视图对比学习模型

然后,我们评估了多视图对比学习预训练的GearNet-Edge模型。我们用预先训练好的模型权重初始化GearNet-Edge。在这个任务上训练模型10个epoch大约需要8分钟。

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=4)

_checkpoint = torch.load("MultiviewContrast_ECToy.pth")["model"]
checkpoint = {}
for k, v in _checkpoint.items():
    if k.startswith("model.model"):
        checkpoint[k[6:]] = v
    else:
        checkpoint[k] = v
checkpoint = {k: v for k, v in checkpoint.items() if not k.startswith("mlp")}
task.load_state_dict(checkpoint, strict=False)

solver.train(num_epoch=10)
solver.evaluate("valid")

微调残基类型预测模型

然后评估残基类型预测预训练的GearNet-Edge模型。在这个任务上训练模型10个epoch仍然需要大约8分钟。

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=4)

checkpoint = torch.load("ResidueTypePrediction_ECToy.pth")["model"]
checkpoint = {k: v for k, v in checkpoint.items() if not k.startswith("mlp")}
task.load_state_dict(checkpoint, strict=False)

solver.train(num_epoch=10)
solver.evaluate("valid")

我们观察到,对预训练模型进行微调优于从头开始训练。然而,这两种方案的性能都不尽如人意,主要原因是数据集规模过小。我们建议用户在更大的蛋白质结构数据集(例如,datasets.AlphaFoldDB)上执行预训练,以充分研究预训练的有效性。

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

发呆的比目鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值