GraphAF源码解读
代码来自 torchdrug
执行代码
## Run script: train GraphAF on ZINC250k for one epoch, then sample molecules.
import torch
from torchdrug import datasets
from torch import nn, optim
from torchdrug import core, models, tasks
from torchdrug.layers import distribution
## Load data: ZINC 250k molecules, kekulized, with symbol atom features
dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True, atom_feature="symbol")
# Relational GCN encoder: one relation per bond type, three 256-dim layers.
model = models.RGCN(input_dim=dataset.num_atom_type,
num_relation=dataset.num_bond_type,
hidden_dims=[256, 256, 256], batch_norm=True)
num_atom_type = dataset.num_atom_type
# add one class for non-edge
num_bond_type = dataset.num_bond_type + 1 ## extra class representing "no bond" between a node pair
# Standard-normal priors for the node flow and the edge flow.
node_prior = distribution.IndependentGaussian(torch.zeros(num_atom_type),
torch.ones(num_atom_type)) ## Gaussian prior over atom-type logits
edge_prior = distribution.IndependentGaussian(torch.zeros(num_bond_type),
torch.ones(num_bond_type))
node_flow = models.GraphAF(model, node_prior, num_layer=12)
edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)
task = tasks.AutoregressiveGeneration(node_flow, edge_flow, max_node=38, max_edge_unroll=12, criterion="nll")
optimizer = optim.Adam(task.parameters(), lr = 1e-3)
solver = core.Engine(task, dataset, None, None, optimizer,
gpus=(1,), batch_size=64, log_interval=10)
# NOTE(review): trains for 1 epoch but the checkpoint name says "10epoch" — confirm intended.
solver.train(num_epoch=1)
solver.save("drug_examples/graphgeneration/graphaf_zinc250k_10epoch.pkl")
solver.load("drug_examples/graphgeneration/graphaf_zinc250k_10epoch.pkl")
# Sample 32 molecules from the trained model and print their SMILES strings.
results = task.generate(num_sample=32)
print(results.to_smiles())
训练数据处理源码
## 训练入口 generation.py --> class AutoregressiveGeneration
def forward(self, batch):
    """Compute the weighted training loss over all configured criteria.

    Parameters:
        batch (dict): one batch of training data, passed through to the
            per-criterion forward functions.

    Returns:
        (Tensor, dict): total weighted loss and the merged metric dictionary.

    Raises:
        ValueError: if an unrecognized criterion name is configured.
    """
    total_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
    metrics = {}
    for name, weight in self.criterion.items():
        # dispatch on the criterion name; each branch yields (loss, metrics)
        if name == "nll":
            loss, stats = self.density_estimation_forward(batch)
        elif name == "ppo":
            loss, stats = self.reinforce_forward(batch)
        else:
            raise ValueError("Unknown criterion `%s`" % name)
        total_loss = total_loss + loss * weight
        metrics.update(stats)
    return total_loss, metrics
def density_estimation_forward(self, batch):
    """Maximum-likelihood (NLL) training step for one batch of graphs.

    Masks the batch into node- and edge-generation subproblems, scores both
    with the corresponding flow models, and accumulates the negative mean
    log-likelihood of each as the loss.

    Returns:
        (Tensor, dict): total NLL loss and logging metrics.
    """
    metrics = {}
    total_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
    graph = batch["graph"]

    # Node step: predict the next atom type for every graph prefix.
    node_graph, node_target = self.mask_node(graph, metrics)
    node_ll = self.node_model(node_graph, node_target, None, total_loss, metrics).mean()
    metrics["node log likelihood"] = node_ll
    total_loss = total_loss - node_ll

    # Edge step: predict the next bond (or non-edge) for every prefix.
    edge_graph, edge_target, edge = self.mask_edge(graph, metrics)
    edge_ll = self.edge_model(edge_graph, edge_target, edge, total_loss, metrics).mean()
    metrics["edge log likelihood"] = edge_ll
    total_loss = total_loss - edge_ll

    return total_loss, metrics
def all_node(self, graph):
    """Enumerate every node-generation step of every graph in the batch.

    For each graph, builds one prefix subgraph (the already-generated nodes)
    per generation step, paired with the atom type of the next node, which is
    the node flow's prediction target.

    Returns:
        (PackedGraph, Tensor): valid prefix subgraphs and their target atom types.
    """
    starts, ends, valid = self._all_prefix_slice(graph.num_nodes)
    # _all_prefix_slice tiles the batch once per prefix length
    repeats = len(starts) // len(graph)
    tiled = graph.repeat(repeats)
    # select the nodes belonging to each prefix
    prefix_mask = functional.multi_slice_mask(starts, ends, tiled.num_node)
    prefix_graph = tiled.subgraph(prefix_mask)
    # the node right after each prefix supplies the prediction target
    target = tiled.subgraph(ends).atom_type
    return prefix_graph[valid], target[valid]
## 边预处理函数
def all_edge(self, graph):
if (graph.num_nodes < 2).any():
graph = graph[graph.num_nodes >= 2]
warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped")
lengths = self._valid_edge_prefix_lengths(graph)
starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths)
num_keep_dense_edges = ends - starts# edge id (max_node_id x max_node_id 的矩阵编号, 一个分子)
num_repeat = len(starts) // len(graph)
graph = graph.repeat(num_repeat)# 复制num个分子,重置graph
# undirected: all upper triangular edge ids are flipped to lower triangular ids 无向:所有上三角形边id都翻转到下三角形id
# 1 -> 2, 4 -> 6, 5 -> 7
node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) # 原来分子的边索引
node_in, node_out = node_index.t()
node_large = node_index.max(dim=-1)[0] # 每条边的最大索引值
node_small = node_index.min(dim=-1)[0] # 每条边的最小索引值
## 下面三段不能理解????
edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small # (node_in >= node_out) * node_large 找到进节点大于出节点的索引
undirected_edge_id = node_large * (node_large + 1) + node_small #下三角edge_id
edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] # num_keep_dense_edges 每个subgraph有多少条边 graph.edge2graph 每条边属于哪个subgraph
circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long()
starts = graph.num_cum_nodes - graph.num_nodes
ends = starts + circum_box_size
node_mask = functional.multi_slice_mask(starts, ends, graph.num_node)
# compact nodes so that succeeding nodes won't affect graph pooling 压缩节点,以便后续节点不会影响图池
new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True)
positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] # 有边的edge_id
positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool()# 在每个subgraph是否有边
# default: non-edge (self.num_bond_type - 1)是没有边
target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device)
target[positive_graph] = graph.edge_list[positive_edge, 2] ##positive_edge的边类型对应到相应位置subgraph的类型
node_in = circum_box_size - 1
node_out = num_keep_dense_edges - node_in * circum_box_size
edge = torch.stack([node_in, node_out], dim=-1)
return new_graph[valid], target[valid], edge[valid]
模型
## RGCN做分子表征
class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
##高斯分布做采样
class IndependentGaussian(nn.Module):
....
# flow做分子生成
class GraphAutoregressiveFlow(nn.Module, core.Configurable):
.....
....
预测
def generate(self, num_sample, max_resample=20, off_policy=False, early_stop=False, verbose=0):
    """Autoregressively sample `num_sample` molecules.

    Nodes are added one at a time; after each new node, bonds to the previous
    `max_edge_unroll` nodes are sampled, resampling invalid (chemically
    infeasible) bonds up to `max_resample` times.

    Parameters:
        num_sample (int): number of molecules to generate
        max_resample (int): resampling attempts per bond for invalid graphs
        off_policy (bool): use the agent (off-policy) models instead of the
            learned node/edge models
        early_stop (bool): additionally mark chemically invalid graphs as
            completed after every node step
        verbose (int): if truthy, log bonds that stay invalid after resampling

    Returns:
        PackedMolecule: generated molecules that pass the RDKit validity check.
    """
    num_relation = self.num_bond_type - 1  # number of real bond types (excluding "non-edge")
    is_training = self.training
    self.eval()

    if off_policy:
        node_model = self.agent_node_model
        edge_model = self.agent_edge_model
    else:
        node_model = self.node_model
        edge_model = self.edge_model

    ## Initial state: an empty PackedMolecule holding num_sample empty graphs.
    edge_list = torch.zeros(0, 3, dtype=torch.long, device=self.device)  # (node_in, node_out, bond_type) triples
    num_nodes = torch.zeros(num_sample, dtype=torch.long, device=self.device)  # per-graph node counts
    num_edges = torch.zeros_like(num_nodes)  # per-graph edge counts
    atom_type = torch.zeros(0, dtype=torch.long, device=self.device)  # flat atom-type ids
    graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                num_relation=num_relation)
    completed = torch.zeros(num_sample, dtype=torch.bool, device=self.device)  # which graphs are finished

    for node_in in range(self.max_node):  # upper bound on atoms per molecule
        atom_pred = node_model.sample(graph)
        # why we add atom_pred even if it is completed?
        # because we need to batch edge model over (node_in, node_out), even on completed graphs
        atom_type, num_nodes = self._append(atom_type, num_nodes, atom_pred)
        graph = node_graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                                 num_relation=num_relation)

        # only bonds to the last max_edge_unroll nodes are considered
        start = max(0, node_in - self.max_edge_unroll)
        for node_out in range(start, node_in):
            is_valid = completed.clone()  # completed graphs are never resampled
            edge = torch.tensor([node_in, node_out], device=self.device).repeat(num_sample, 1)
            # default: non-edge
            bond_pred = (self.num_bond_type - 1) * torch.ones(num_sample, dtype=torch.long, device=self.device)
            for i in range(max_resample):
                # only resample invalid graphs
                mask = ~is_valid
                bond_pred[mask] = edge_model.sample(graph, edge)[mask]
                # check valency: trial-append the bond and test chemical validity.
                # mask selects graphs that predicted a real bond (class index
                # below the non-edge class) and are not yet completed.
                mask = (bond_pred < edge_model.input_dim - 1) & ~completed

                edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1)
                tmp_edge_list, tmp_num_edges = self._append(edge_list, num_edges, edge_pred, mask)  # forward direction
                edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1)
                tmp_edge_list, tmp_num_edges = self._append(tmp_edge_list, tmp_num_edges, edge_pred, mask)  # reverse direction (undirected edge)
                tmp_graph = data.PackedMolecule(tmp_edge_list, self.id2atom[atom_type], tmp_edge_list[:, -1],
                                                num_nodes, tmp_num_edges, num_relation=num_relation)
                is_valid = tmp_graph.is_valid | completed  # chemically valid, or already finished
                if is_valid.all():  # every sample accepted this bond
                    break

            if not is_valid.all() and verbose:
                num_invalid = num_sample - is_valid.sum().item()
                num_working = num_sample - completed.sum().item()
                logger.warning("edge (%d, %d): %d / %d molecules are invalid even after %d resampling" %
                               (node_in, node_out, num_invalid, num_working, max_resample))

            ## Commit the accepted bond (both directions) to the real edge list.
            mask = (bond_pred < edge_model.input_dim - 1) & ~completed  # real-bond, not-yet-completed graphs
            edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1)
            edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask)
            edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1)  # edge.flip(-1) reverses (in, out)
            edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask)  # undirected edges are stored twice
            graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges,
                                        num_relation=num_relation)

        if node_in > 0:
            assert (graph.num_edges[completed] == node_graph.num_edges[completed]).all()
            # a graph that gained no edge for this node is considered finished
            completed |= graph.num_edges == node_graph.num_edges
            if early_stop:
                # temporarily map ids to real atomic numbers to run validity checks
                graph.atom_type = self.id2atom[graph.atom_type]
                completed |= ~graph.is_valid
                graph.atom_type = self.atom2id[graph.atom_type]
            if completed.all():
                break

    self.train(is_training)

    # remove isolated atoms
    index = graph.degree_out > 0
    # keep at least the first atom for each graph
    index[graph.num_cum_nodes - graph.num_nodes] = 1
    graph = graph.subgraph(index)
    graph.atom_type = self.id2atom[graph.atom_type]  # convert ids back to atomic numbers
    graph = graph[graph.is_valid_rdkit]  # keep only molecules RDKit accepts
    return graph