TorchDrug教程--预训练的分子表示

TorchDrug教程–预训练的分子表示

教程来源TorchDrug开源

目录

在许多药物发现任务中,收集标记数据在时间和金钱上都是昂贵的。作为一种解决方案,引入了自监督预训练来从大量未标记的数据中学习分子表示。

在本教程中,我们将演示如何在分子上预训练图神经网络,以及如何在下游任务上微调模型。

自我监督预训练

预训练是在图神经网络中进行图级属性预测的一种有效的迁移学习方法。在这里,我们专注于通过不同的自我监督策略预训练GNNs。这些方法通常基于分子的结构信息构建无监督损失函数。

为了说明原因,我们在本教程中只使用ClinTox数据集,它比标准的预训练数据集要小得多。

Infograph

InfoGraph (IG)建议最大化图级和节点级表示之间的互信息。它通过区分节点图对是来自单个图还是来自两个不同的图来学习模型。下图展示了InfoGraph的高级概念。

我们使用GIN作为我们的图形表示模型,并用InfoGraph包装它。

import torch
from torch import nn
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                           bond_feature="pretrain")

gin_model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300],
                       edge_input_dim=dataset.edge_feature_dim,
                       batch_norm=True, readout="mean")
model = models.InfoGraph(gin_model, separate_model=False)

task = tasks.Unsupervised(model)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=100)
solver.save("clintox_gin_infograph.pth")

经过训练,表示的相互信息可能接近

average graph-node mutual information: 1.30658

Attribute Masking

属性masking的目的是通过学习分布在图结构上的节点/边属性的规律来获取领域知识。高层次的思想是通过随机掩盖的节点特征来预测分子图中的原子类型。

同样,我们使用GIN作为我们的图表示模型。

import torch
from torch import nn, optim
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

dataset = datasets.ClinTox("~/molecule-datasets/", atom_feature="pretrain",
                           bond_feature="pretrain")

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[300, 300, 300, 300, 300],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")
task = tasks.AttributeMasking(model, mask_rate=0.15)

optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=100)
solver.save("clintox_gin_attributemasking.pth")

通常,训练精度和交叉熵看起来如下所示。

average accuracy: 0.920366
average cross entropy: 0.22998

除了InfoGraph和Attribute Masking, gnn的预训练还有一些其他的策略。有关详细信息,请参阅下面的文档。

InfoGraph, AttributeMasking, EdgePrediction, ContextPrediction

关于标记数据集的Finetune

当GNN预训练完成后,我们可以在下游任务上对预训练的GNN模型进行微调。这里我们使用BACE数据集进行说明,该数据集包含1513个具有结合亲和力的人β-分泌酶1(BACE-1)抑制剂分子。

首先,我们下载BACE数据集,并将其分为训练集、验证集和测试集。注意,我们需要将数据集中的节点和边缘特征设置为预训练,以使其与预训练的模型兼容。

dataset = datasets.BACE("~/molecule-datasets/",
                        atom_feature="pretrain", bond_feature="pretrain")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = data.ordered_scaffold_split(dataset, lengths)

然后,我们定义与预训练阶段相同的模型,并为我们的下游任务设置优化器和求解器。这里唯一的区别是我们使用PropertyPrediction任务来支持监督学习。

model = models.GIN(input_dim=dataset.node_feature_dim,
                hidden_dims=[300, 300, 300, 300, 300],
                edge_input_dim=dataset.edge_feature_dim,
                batch_norm=True, readout="mean")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=256)

现在我们可以加载预训练的模型,并在下游数据集上对其进行微调。

checkpoint = torch.load("clintox_gin_attributemasking.pth")["model"]
task.load_state_dict(checkpoint, strict=False)

solver.train(num_epoch=100)
solver.evaluate("valid")

一旦模型训练好了,我们就在验证集上评估它。结果可能类似于下面的情况。

auprc [Class]: 0.921956
auroc [Class]: 0.663004
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

发呆的比目鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值