GraphAF源码解读

GraphAF源码解读

代码来自 torchdrug

执行代码

import torch
from torchdrug import datasets
from torch import nn, optim
from torchdrug import core, models, tasks
from torchdrug.layers import distribution
## Load the ZINC250k dataset (kekulized molecules, symbolic atom features).
dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True, atom_feature="symbol")

# Relational GCN encodes the partial molecular graph at every generation step.
model = models.RGCN(input_dim=dataset.num_atom_type,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256], batch_norm=True)

num_atom_type = dataset.num_atom_type
# add one class for non-edge
num_bond_type = dataset.num_bond_type + 1  ## extra class representing "no bond between this node pair"

# Standard-normal priors for the node flow and the edge flow.
node_prior = distribution.IndependentGaussian(torch.zeros(num_atom_type),
                                              torch.ones(num_atom_type))
edge_prior = distribution.IndependentGaussian(torch.zeros(num_bond_type),
                                              torch.ones(num_bond_type))
node_flow = models.GraphAF(model, node_prior, num_layer=12)
edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)

# Train by negative log-likelihood; molecules are capped at 38 atoms, and each
# new atom may only bond to one of the previous 12 atoms (max_edge_unroll).
task = tasks.AutoregressiveGeneration(node_flow, edge_flow, max_node=38, max_edge_unroll=12, criterion="nll")

optimizer = optim.Adam(task.parameters(), lr = 1e-3)
# NOTE(review): gpus=(1,) selects GPU index 1 — adjust to your hardware.
solver = core.Engine(task, dataset, None, None, optimizer,
                     gpus=(1,), batch_size=64, log_interval=10)

solver.train(num_epoch=1)
# NOTE(review): the checkpoint name says "10epoch" but only 1 epoch is trained above.
solver.save("drug_examples/graphgeneration/graphaf_zinc250k_10epoch.pkl")

solver.load("drug_examples/graphgeneration/graphaf_zinc250k_10epoch.pkl")
# Sample 32 molecules from the trained flows and print their SMILES strings.
results = task.generate(num_sample=32)
print(results.to_smiles())

训练数据处理源码

## 训练入口 generation.py --> class AutoregressiveGeneration
    def forward(self, batch):
        """Aggregate the weighted losses of every configured training criterion.

        Parameters:
            batch (dict): training batch, forwarded unchanged to the
                per-criterion sub-forwards.

        Returns:
            (Tensor, dict): scalar total loss and a merged metric dict.

        Raises:
            ValueError: if ``self.criterion`` contains an unknown key.
        """
        metric = {}
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)

        # Resolve each criterion name to its sub-forward; anything else is a
        # configuration error.
        for name, weight in self.criterion.items():
            if name == "nll":
                sub_forward = self.density_estimation_forward
            elif name == "ppo":
                sub_forward = self.reinforce_forward
            else:
                raise ValueError("Unknown criterion `%s`" % name)
            loss, sub_metric = sub_forward(batch)
            all_loss += loss * weight
            metric.update(sub_metric)

        return all_loss, metric


def density_estimation_forward(self, batch):
    """One maximum-likelihood (NLL) training step for GraphAF.

    Scores the node-generation prefixes and the edge-generation prefixes of
    the batched molecules under the node/edge flow models and accumulates the
    negative mean log-likelihood of each component.

    Parameters:
        batch (dict): must contain "graph", a batched molecule graph.

    Returns:
        (Tensor, dict): total loss and per-component log-likelihood metrics.
    """
    all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
    metric = {}
    graph = batch["graph"]

    # Node component: predict the next atom type for every generation prefix.
    node_graph, node_target = self.mask_node(graph, metric)
    node_ll = self.node_model(node_graph, node_target, None, all_loss, metric).mean()
    metric["node log likelihood"] = node_ll
    # In-place so the same tensor object is visible to the edge model below.
    all_loss += -node_ll

    # Edge component: predict the next bond (or non-edge) for every prefix.
    edge_graph, edge_target, edge = self.mask_edge(graph, metric)
    edge_ll = self.edge_model(edge_graph, edge_target, edge, all_loss, metric).mean()
    metric["edge log likelihood"] = edge_ll
    all_loss += -edge_ll

    return all_loss, metric

def all_node(self, graph):
    """Expand each molecule into all of its node-generation prefixes.

    For every molecule the batch is replicated once per prefix length, so a
    molecule with k atoms contributes k partial graphs, each paired with the
    type of the atom that should be generated next.

    Fix: the original snippet mixed tabs and spaces for indentation, which is
    a TabError under Python 3; indentation is normalized to 4 spaces here.

    Parameters:
        graph: batched molecule graphs (torchdrug PackedGraph-like).

    Returns:
        (graph, Tensor): valid prefix subgraphs and their target atom types.
    """
    # One (start, end) node slice per prefix of every molecule in the batch.
    starts, ends, valid = self._all_prefix_slice(graph.num_nodes)

    # len(starts) // len(graph) equals the number of prefixes per molecule,
    # i.e. the maximum atom count in this batch.
    num_repeat = len(starts) // len(graph)
    graph = graph.repeat(num_repeat)

    # Keep only the nodes that fall inside each prefix slice.
    mask = functional.multi_slice_mask(starts, ends, graph.num_node)
    new_graph = graph.subgraph(mask)
    # The node right after each prefix is the prediction target.
    target = graph.subgraph(ends).atom_type

    return new_graph[valid], target[valid]

## 边预处理函数
def all_edge(self, graph):
        if (graph.num_nodes < 2).any():
            graph = graph[graph.num_nodes >= 2]
            warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped")

        lengths = self._valid_edge_prefix_lengths(graph)

        starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths)

        num_keep_dense_edges = ends - starts# edge id (max_node_id x max_node_id 的矩阵编号, 一个分子)
        num_repeat = len(starts) // len(graph)
        graph = graph.repeat(num_repeat)# 复制num个分子,重置graph

        # undirected: all upper triangular edge ids are flipped to lower triangular ids 无向:所有上三角形边id都翻转到下三角形id
        # 1 -> 2, 4 -> 6, 5 -> 7
        node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) #   原来分子的边索引
        node_in, node_out = node_index.t()
        node_large = node_index.max(dim=-1)[0] # 每条边的最大索引值
        node_small = node_index.min(dim=-1)[0] # 每条边的最小索引值
        ## 下面三段不能理解????
        edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small # (node_in >= node_out) * node_large 找到进节点大于出节点的索引
        undirected_edge_id = node_large * (node_large + 1) + node_small #下三角edge_id
        
        edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] # num_keep_dense_edges 每个subgraph有多少条边 graph.edge2graph 每条边属于哪个subgraph
        circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long()
        starts = graph.num_cum_nodes - graph.num_nodes
        ends = starts + circum_box_size
        node_mask = functional.multi_slice_mask(starts, ends, graph.num_node)
        # compact nodes so that succeeding nodes won't affect graph pooling 压缩节点,以便后续节点不会影响图池
        new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True)

        positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] # 有边的edge_id
        positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool()# 在每个subgraph是否有边
        # default: non-edge (self.num_bond_type - 1)是没有边
        target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device)
        target[positive_graph] = graph.edge_list[positive_edge, 2] ##positive_edge的边类型对应到相应位置subgraph的类型

        node_in = circum_box_size - 1
        node_out = num_keep_dense_edges - node_in * circum_box_size
        edge = torch.stack([node_in, node_out], dim=-1)

        return new_graph[valid], target[valid], edge[valid]

模型

## RGCN做分子表征
class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
##高斯分布做采样
class IndependentGaussian(nn.Module):
....
# flow做分子生成
class GraphAutoregressiveFlow(nn.Module, core.Configurable):
.....
....

预测

def generate(self, num_sample, max_resample=20, off_policy=False, early_stop=False, verbose=0):
        """Autoregressively sample ``num_sample`` molecules from the flows.

        Atoms are generated one at a time; after each new atom, bonds to the
        previous ``max_edge_unroll`` atoms are sampled, resampling up to
        ``max_resample`` times until valency checks pass.

        Parameters:
            num_sample (int): number of molecules to sample.
            max_resample (int, optional): max resampling rounds per edge.
            off_policy (bool, optional): sample from the agent (RL) models
                instead of the plain trained models.
            early_stop (bool, optional): mark chemically invalid graphs as
                completed as soon as they appear.
            verbose (int, optional): if nonzero, warn about samples that stay
                invalid after resampling.

        Returns:
            Generated molecules that RDKit accepts as valid.
        """
        num_relation = self.num_bond_type - 1  # count of real bond types; the extra class means "no edge"
        is_training = self.training
        self.eval()

        if off_policy:
            node_model = self.agent_node_model
            edge_model = self.agent_edge_model
        else:
            node_model = self.node_model
            edge_model = self.edge_model
        # Start from an empty batched graph: no atoms, no edges, and one empty
        # subgraph per requested sample.
        edge_list = torch.zeros(0, 3, dtype=torch.long, device=self.device)
        num_nodes = torch.zeros(num_sample, dtype=torch.long, device=self.device)
        num_edges = torch.zeros_like(num_nodes)
        atom_type = torch.zeros(0, dtype=torch.long, device=self.device)
        graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                    num_relation=num_relation)
        completed = torch.zeros(num_sample, dtype=torch.bool, device=self.device)  # per-sample "finished" flag

        for node_in in range(self.max_node):  # grow up to max_node atoms per molecule
            atom_pred = node_model.sample(graph)
            # why we add atom_pred even if it is completed?
            # because we need to batch edge model over (node_in, node_out), even on completed graphs
            atom_type, num_nodes = self._append(atom_type, num_nodes, atom_pred)
            graph = node_graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                                     num_relation=num_relation)

            # Only consider bonds to the most recent max_edge_unroll atoms.
            start = max(0, node_in - self.max_edge_unroll)
            for node_out in range(start, node_in):
                is_valid = completed.clone()  # completed graphs never need resampling
                edge = torch.tensor([node_in, node_out], device=self.device).repeat(num_sample, 1)
                # default: non-edge
                bond_pred = (self.num_bond_type - 1) * torch.ones(num_sample, dtype=torch.long, device=self.device)
                for i in range(max_resample):
                    # only resample invalid graphs
                    mask = ~is_valid
                    bond_pred[mask] = edge_model.sample(graph, edge)[mask]
                    # check valency
                    # NOTE(review): bond_pred < input_dim - 1 presumably means
                    # "a real bond was predicted" (last class = non-edge) — confirm.
                    mask = (bond_pred < edge_model.input_dim - 1) & ~completed
                    edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1)
                    tmp_edge_list, tmp_num_edges = self._append(edge_list, num_edges, edge_pred, mask)  # forward direction
                    edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1)
                    tmp_edge_list, tmp_num_edges = self._append(tmp_edge_list, tmp_num_edges, edge_pred, mask)  # reverse direction (undirected graph)
                    tmp_graph = data.PackedMolecule(tmp_edge_list, self.id2atom[atom_type], tmp_edge_list[:, -1],
                                                    num_nodes, tmp_num_edges, num_relation=num_relation)

                    is_valid = tmp_graph.is_valid | completed  # chemically valid, or already finished

                    if is_valid.all():  # every sample accepted this bond decision
                        break

                if not is_valid.all() and verbose:
                    num_invalid = num_sample - is_valid.sum().item()
                    num_working = num_sample - completed.sum().item()
                    logger.warning("edge (%d, %d): %d / %d molecules are invalid even after %d resampling" %
                                   (node_in, node_out, num_invalid, num_working, max_resample))
                # Commit the accepted (node_in, node_out) bonds to the graph.
                mask = (bond_pred < edge_model.input_dim - 1) & ~completed
                edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1)
                edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask)
                edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1)  # edge.flip(-1) swaps (in, out)
                edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask)  # store both directions of the undirected edge
                graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                            num_relation=num_relation)

            if node_in > 0:
                assert (graph.num_edges[completed] == node_graph.num_edges[completed]).all()
                # A molecule is finished once its newest atom attached to nothing
                # (edge count unchanged by this round).
                completed |= graph.num_edges == node_graph.num_edges
                if early_stop:
                    # Temporarily map ids back to atomic numbers for the validity check.
                    graph.atom_type = self.id2atom[graph.atom_type]
                    completed |= ~graph.is_valid
                    graph.atom_type = self.atom2id[graph.atom_type]
                if completed.all():
                    break

        self.train(is_training)

        # remove isolated atoms
        index = graph.degree_out > 0
        # keep at least the first atom for each graph
        index[graph.num_cum_nodes - graph.num_nodes] = 1
        graph = graph.subgraph(index)
        graph.atom_type = self.id2atom[graph.atom_type]

        graph = graph[graph.is_valid_rdkit]
        return graph
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

发呆的比目鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值