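As mentioned in the previous post, the authors of Strategies for Pre-training Graph Neural Networks proposed two node-level pretraining methods: Context Prediction and Attribute Prediction. Both methods let the model learn node-level embeddings. When the pretrained model is then used for downstream graph-level tasks, it preserves node-level information well and generalizes better.
This part describes Context Prediction (substructure prediction) in detail and provides a version of the code that can be run directly.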
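1. Introduction to Context Prediction pretraining
The Context Prediction pretraining scheme is illustrated below:
First, convert the molecule's SMILES string into a graph and randomly pick one node as the center node V.
For any center node V we can extract its K-hop subgraph, i.e. all nodes within K hops of V. In GNN terms, after K rounds of message passing every node in the K-hop subgraph has influenced the embedding of V.
Given the r1-hop and r2-hop subgraphs around V (with r1 < r2), the nodes that are in the r2-hop subgraph but not in the r1-hop subgraph form the context graph. It is the substructure whose distance from V lies between r1 and r2 hops, i.e. the part between the two pink dashed circles in the figure above.
As for r1, r2 and K, the paper requires r1 < K so that the K-hop subgraph and the context graph share some nodes (the setting used here also has K < r2). The nodes between r1 and K hops, which belong to both subgraphs, are called context anchor nodes (the nodes between the small pink circle and the blue circle). They describe how the neighborhood of V and its context graph are connected to each other.
The context graph and the context anchor nodes are the two key concepts here.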
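To make these node sets concrete, here is a small pure-Python sketch (no RDKit or NetworkX) that computes hop distances by breadth-first search and derives the K-hop neighborhood, the context graph nodes and the context anchor nodes. It assumes, like nx.single_source_shortest_path_length with a cutoff, that "within r hops" means distance ≤ r; the toy path graph and the helper names are hypothetical.

```python
from collections import deque


def hop_distances(adj, root):
    """Breadth-first search: shortest hop count from root to every reachable node."""
    dist = {root: 0}
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


def context_split(adj, root, k, r1, r2):
    """Return (K-hop neighborhood, context graph nodes, context anchor nodes)."""
    dist = hop_distances(adj, root)
    neighborhood = {n for n, d in dist.items() if d <= k}  # within K hops
    context = {n for n, d in dist.items() if r1 < d <= r2}  # between r1 and r2 hops
    anchors = neighborhood & context                        # shared nodes: r1 < d <= K
    return neighborhood, context, anchors


# Toy path graph 0-1-2-...-7 rooted at node 0, with K=3, r1=2, r2=6 as used later in this post
adj = {i: [j for j in (i - 1, i + 1) if 0 <= j < 8] for i in range(8)}
neighborhood, context, anchors = context_split(adj, 0, k=3, r1=2, r2=6)
print(sorted(neighborhood))  # [0, 1, 2, 3]
print(sorted(context))       # [3, 4, 5, 6]
print(sorted(anchors))       # [3]
```

With r1 = 2 and K = 3, only nodes exactly 3 hops away are anchors; widening the gap between r1 and K yields more anchors per center node.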
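The goal of Context Prediction pretraining is to decide whether a neighborhood (the K-hop subgraph of a center node) and a context graph belong to the same center node, which is done by comparing the center node with the context anchor nodes.
Concretely, there are two GNN models, model1 and model2. model1 is the main model: it runs message passing on the center node's K-hop subgraph and outputs an embedding for every node. model2 runs message passing only on the context graph and outputs embeddings for the context graph's nodes. From model1 we take the center node's embedding X(root); from model2 we take the embeddings of the context anchor nodes, X(anchor). For negatives, the center nodes of other molecules are used as pseudo center nodes, and their model1 embeddings X(root-false) are extracted.
Self-supervised objective: X(root) should be as similar as possible to the X(anchor) of its own context anchor nodes, while X(root-false) should be dissimilar to those X(anchor). In other words, nodes whose surrounding structure and chemical environment are similar get similar embeddings, which is how model1 learns to recognize molecular substructures.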
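Written as a formula, assuming the per-anchor pairing and mean reduction used by the train function later in this post: h_v = X(root), g_a = X(anchor), h_{v'} = X(root-false), σ is the sigmoid, P is the set of (center, own anchor) pairs and N the set of (other molecule's center, anchor) pairs:

```latex
\mathcal{L} = -\frac{1}{|P|}\sum_{(v,a)\in P}\log\sigma\!\left(h_v^{\top} g_a\right)
              -\frac{1}{|N|}\sum_{(v',a)\in N}\log\!\left(1-\sigma\!\left(h_{v'}^{\top} g_a\right)\right),
\qquad \sigma(z) = \frac{1}{1+e^{-z}}
```

The first term corresponds to loss_pos and the second to loss_neg (with neg_samples = 1), each computed by BCEWithLogitsLoss on the raw inner products.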
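model1 is the model we want to pretrain.
In summary, the steps are:
- build the full molecular graph;
- from r1, r2 and K, build the molecule's K-hop subgraph and context graph, and mark the context anchor nodes;
- sample pseudo center nodes as negatives at a 1:1 ratio;
- during training, score whether the center node and the pseudo center node are each similar to the context anchor node embeddings, using the negative log-likelihood (binary cross-entropy) as the loss.
Purpose of the self-supervised training: nodes with similar surrounding structure get similar embeddings, and nodes with dissimilar structure get different embeddings.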
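2. Code walkthrough
This section walks through the code.
The environment for running this code is described in Pyg_pretrain.yml. Note that I use torch_geometric 2.0.3, the most recent version at the time of writing, whereas the paper used 1.0.3, which can no longer be downloaded and installed. See Section 4 for the runtime environment.
2.1 Import the required libraries
These are pandas, networkx, torch and torch_geometric, plus RDKit, which is needed to parse molecules.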
import pandas as pd
from tqdm import tqdm
import numpy as np
import networkx as nx
import os
import math
import random
import torch
import torch.optim as optim
torch.manual_seed(0)
np.random.seed(0)
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
import torch.nn.functional as F
from torch_geometric.data import Data  # PyG graph data type
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
random.seed(0)  # random.sample() is used to pick the center node
2.2 Data preprocessing and loading
The allowed atom and bond feature values are:
# allowable node and edge features
# In practice only two atom features and two bond features are used
allowable_features = {
    'possible_atomic_num_list' : list(range(1, 119)),  # element (atomic number 1-118)
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],  # formal charge
    'possible_chirality_list' : [  # chirality type
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [  # orbital hybridization type
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],  # number of hydrogens
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],  # implicit valence
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # degree
    'possible_bonds' : [  # bond type
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [  # bond direction, only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}
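Although many node (atom) and edge (bond) feature types are allowed above, the paper uses only the atomic number and the chirality tag as atom features; it does not use, for example, the atom degree or the number of unsaturated bonds. Because the atomic number is used directly as an embedding index, the atom-type embedding table needs roughly 120 entries (num_atom_type=120 in the GNN below). For this reason, the way the paper builds its graphs may not be ideal.
Convert the RDKit mol object into a PyG data object: node features go into data.x, edge features into data.edge_attr, and edges into data.edge_index. There are num_atom_features = 2 node features (atom type and chirality) and num_bond_features = 2 edge features (bond type and bond direction).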
def mol_to_graph_data_obj_simple(mol):
    """
    Convert an RDKit mol object into a PyG Data object.
    """
    # Nodes: each atom has only two features, atom type + chirality tag
    num_atom_features = 2  # atom type, chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # Edges: each bond has only two features, bond type + bond direction
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                'possible_bond_dirs'].index(
                bond.GetBondDir())]
            # store each undirected bond as two directed edges
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)  # Data is PyG's graph type
    return data
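mol_to_graph_data_obj_simple stores every undirected bond as two directed edges with identical features, which is why graph_data_obj_to_nx_simple (later in this post) can read edge_index in steps of 2. Here is a dependency-free sketch of that layout; bonds_to_coo is a hypothetical helper that takes (i, j, features) tuples instead of RDKit bonds, and the feature values are toy indices.

```python
def bonds_to_coo(bonds):
    """
    Turn undirected bonds [(i, j, features), ...] into a COO edge_index
    (row 0 = sources, row 1 = targets) plus one feature list per directed edge.
    Each bond is emitted twice, i->j then j->i, with identical features.
    """
    sources, targets, edge_features = [], [], []
    for i, j, feat in bonds:
        for a, b in ((i, j), (j, i)):
            sources.append(a)
            targets.append(b)
            edge_features.append(list(feat))
    return [sources, targets], edge_features


# Toy 3-atom chain; features are (bond type index, bond direction index)
edge_index, edge_attr = bonds_to_coo([(0, 1, (0, 0)), (1, 2, (1, 0))])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr)   # [[0, 0], [0, 0], [1, 0], [1, 0]]
```

Columns 0, 2, 4, ... hold one direction of each bond, so stepping through edge_index two columns at a time visits every bond exactly once.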
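The MoleculeDataset class is built on PyG's InMemoryDataset: it converts molecules from SMILES into PyG graphs and stores them as a PyG dataset. By default it loads the zinc250k dataset, stored in the dataset folder. It uses mol_to_graph_data_obj_simple above to convert each molecule into a PyG graph, then merges the list into the dataset with self.collate(data_list).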
class MoleculeDataset(InMemoryDataset):
    '''
    Load the ZINC dataset as a PyG Dataset.
    '''
    def __init__(self, root, dataset='zinc250k',
                 transform=None, pre_transform=None,
                 pre_filter=None):
        print(dataset)
        self.dataset = dataset
        self.root = root
        # Must come after the attributes above: the parent __init__ calls process() if needed
        super(MoleculeDataset, self).__init__(root, transform, pre_transform,
                                              pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property  # list of raw files
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        return file_name_list

    @property  # processed file; if it already exists, process() is skipped
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def process(self):
        data_smiles_list = []
        data_list = []
        input_path = self.raw_paths[0]
        input_df = pd.read_csv(input_path, sep=',', compression='gzip',
                               dtype='str')
        smiles_list = list(input_df['smiles'])
        zinc_id_list = list(input_df['zinc_id'])
        for i in range(len(smiles_list)):
            if i % 1000 == 0:
                print(str(i) + '...')
            s = smiles_list[i]
            try:
                rdkit_mol = AllChem.MolFromSmiles(s, sanitize=True)
                if rdkit_mol is not None:  # ignore invalid mol objects
                    # # convert aromatic bonds to double bonds
                    # Chem.SanitizeMol(rdkit_mol,sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                    data = mol_to_graph_data_obj_simple(rdkit_mol)
                    # manually add mol id
                    mol_id = int(zinc_id_list[i].split('ZINC')[1].lstrip('0'))
                    data.id = torch.tensor([mol_id])  # id here is zinc id value, stripped of
                    # leading zeros
                    data_list.append(data)
                    data_smiles_list.append(smiles_list[i])
            except Exception:  # skip molecules that fail to parse or convert
                continue

        # filter
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        # pre-transform
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # write data_smiles_list in processed paths
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir,
                                               'content_smiles.csv'), index=False,
                                  header=False)
        # InMemoryDataset method: merge the list of torch_geometric.data.Data into internal storage,
        # which is saved to processed_paths[0]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    # string representation
    def __repr__(self):
        return '{}()'.format(self.dataset)
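self.collate(data_list) returns a (data, slices) pair: all graphs are concatenated and slices records where each graph starts and ends, so that individual graphs can be cut out again later. The following is only a conceptual sketch of that idea on plain lists, not PyG's actual implementation; collate_lists and get_graph are hypothetical names and the feature rows are toy values.

```python
def collate_lists(graph_feature_lists):
    """
    Conceptual sketch: concatenate per-graph node-feature lists into one flat list
    and record slice boundaries, so graph g is flat[slices[g]:slices[g + 1]].
    """
    flat, slices = [], [0]
    for feats in graph_feature_lists:
        flat.extend(feats)
        slices.append(slices[-1] + len(feats))
    return flat, slices


def get_graph(flat, slices, g):
    """Recover the node-feature rows of graph g from the flat storage."""
    return flat[slices[g]:slices[g + 1]]


# Two toy molecules with 2 and 3 atoms, features [atom type index, chirality index]
flat, slices = collate_lists([[[5, 0], [7, 0]], [[5, 0], [5, 0], [7, 0]]])
print(slices)                      # [0, 2, 5]
print(get_graph(flat, slices, 1))  # [[5, 0], [5, 0], [7, 0]]
```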
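As explained in the first part, we need to sample a context graph on each molecular graph. That is, each molecule's PyG data object is augmented with context-related attributes: the K-hop subgraph and its indices; the context graph and its indices; and the context anchor nodes with their indices in the context graph.
Processing pipeline:
molecule → mol object → PyG graph → NetworkX graph → K-hop subgraph / context graph → context anchor nodes → PyG graph attributes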
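This pipeline is implemented in the ExtractSubstructureContextPair class, which adds the following attributes to each molecular graph:
- data.center_substruct_idx: index of the center node within the K-hop subgraph
- data.x_substruct: node features of the K-hop subgraph
- data.edge_attr_substruct: edge features of the K-hop subgraph
- data.edge_index_substruct: edges of the K-hop subgraph
- data.x_context: node features of the context graph
- data.edge_attr_context: edge features of the context graph
- data.edge_index_context: edges of the context graph
- data.overlap_context_substruct_idx: indices of the context anchor nodes (in the context graph's numbering)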
class ExtractSubstructureContextPair:
    def __init__(self, k, l1, l2):
        """
        Randomly pick a root (center) node, then mark the K-hop subgraph and
        the context graph according to k, r1 (l1) and r2 (l2).
        """
        self.k = k    # K in the paper
        self.l1 = l1  # r1 in the paper
        self.l2 = l2  # r2 in the paper

        # A value of 0 is replaced by -1
        # for the special case of 0, addresses the quirk with
        # single_source_shortest_path_length
        if self.k == 0:
            self.k = -1
        if self.l1 == 0:
            self.l1 = -1
        if self.l2 == 0:
            self.l2 = -1

    def __call__(self, data, root_idx=None):
        """
        Adds the following attributes to the PyG data object:
        data.center_substruct_idx
        data.x_substruct
        data.edge_attr_substruct
        data.edge_index_substruct
        data.x_context
        data.edge_attr_context
        data.edge_index_context
        data.overlap_context_substruct_idx
        """
        num_atoms = data.x.size()[0]  # number of nodes
        # If no root is given, pick one at random
        if root_idx is None:
            root_idx = random.sample(range(num_atoms), 1)[0]

        # Convert the PyG data object into a NetworkX graph
        G = graph_data_obj_to_nx_simple(data)  # same ordering as input data obj

        # Node indices of the root's K-hop subgraph (distance <= k)
        substruct_node_idxes = nx.single_source_shortest_path_length(G,
                                                                      root_idx,
                                                                      self.k).keys()
        # If the K-hop subgraph is non-empty
        if len(substruct_node_idxes) > 0:
            substruct_G = G.subgraph(substruct_node_idxes)  # K-hop subgraph
            substruct_G, substruct_node_map = reset_idxes(substruct_G)  # renumber its nodes from 0
            substruct_data = nx_to_graph_data_obj_simple(substruct_G)  # NetworkX -> PyG
            # Attach the K-hop subgraph to the molecule's PyG graph
            data.x_substruct = substruct_data.x
            data.edge_attr_substruct = substruct_data.edge_attr
            data.edge_index_substruct = substruct_data.edge_index
            # Index of the center node inside the K-hop subgraph
            data.center_substruct_idx = torch.tensor([substruct_node_map[
                                                          root_idx]])

        # Nodes within r1 hops
        l1_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l1).keys()
        # Nodes within r2 hops
        l2_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l2).keys()
        # Nodes between r1 and r2 hops (within r2 but not within r1)
        context_node_idxes = set(l1_node_idxes).symmetric_difference(
            set(l2_node_idxes))
        # If there are such nodes, build the context graph
        if len(context_node_idxes) > 0:
            context_G = G.subgraph(context_node_idxes)  # extract the context graph
            context_G, context_node_map = reset_idxes(context_G)  # renumber its nodes from 0
            context_data = nx_to_graph_data_obj_simple(context_G)  # NetworkX -> PyG
            # Attach the context graph to the molecule's PyG graph
            data.x_context = context_data.x
            data.edge_attr_context = context_data.edge_attr
            data.edge_index_context = context_data.edge_index

        # Context anchor nodes: shared by the context graph and the K-hop subgraph
        context_substruct_overlap_idxes = list(set(
            context_node_idxes).intersection(set(substruct_node_idxes)))
        # If there are context anchor nodes
        if len(context_substruct_overlap_idxes) > 0:
            # Record the anchor node indices in the context graph's numbering
            context_substruct_overlap_idxes_reorder = [context_node_map[old_idx]
                                                       for
                                                       old_idx in
                                                       context_substruct_overlap_idxes]
            # need to convert the overlap node idxes, which is from the
            # original graph node ordering to the new context node ordering
            data.overlap_context_substruct_idx = \
                torch.tensor(context_substruct_overlap_idxes_reorder)

        return data

    def __repr__(self):
        return '{}(k={},l1={}, l2={})'.format(self.__class__.__name__, self.k,
                                              self.l1, self.l2)
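ExtractSubstructureContextPair converts the PyG graph into a NetworkX graph and, after processing, converts the results back into PyG graphs, so it relies on the following helper functions: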
# Convert a NetworkX graph into a PyG data object
def nx_to_graph_data_obj_simple(G):
    """
    Convert a NetworkX graph into a PyG Data object.
    """
    # atoms (nodes)
    num_atom_features = 2  # atom type, chirality tag
    atom_features_list = []
    for _, node in G.nodes(data=True):
        atom_feature = [node['atom_num_idx'], node['chirality_tag_idx']]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds (edges)
    num_bond_features = 2  # bond type, bond direction
    # If the graph has edges
    if len(G.edges()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for i, j, edge in G.edges(data=True):
            edge_feature = [edge['bond_type_idx'], edge['bond_dir_idx']]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:
        # No edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data


# Convert a PyG data object into a NetworkX graph
# (used to find subgraphs)
def graph_data_obj_to_nx_simple(data):
    # data is a PyG graph
    G = nx.Graph()

    # atoms
    atom_features = data.x.cpu().numpy()  # node features
    num_atoms = atom_features.shape[0]  # number of nodes
    # Add the nodes and their features to the NetworkX graph
    for i in range(num_atoms):
        atomic_num_idx, chirality_tag_idx = atom_features[i]
        G.add_node(i, atom_num_idx=atomic_num_idx, chirality_tag_idx=chirality_tag_idx)

    # Add the edges and their features; each bond is stored twice (i->j, j->i),
    # so every other column of edge_index is enough
    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    num_bonds = edge_index.shape[1]
    for j in range(0, num_bonds, 2):
        begin_idx = int(edge_index[0, j])
        end_idx = int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j]  # edge features
        if not G.has_edge(begin_idx, end_idx):
            # add the edge
            G.add_edge(begin_idx, end_idx, bond_type_idx=bond_type_idx,
                       bond_dir_idx=bond_dir_idx)

    return G
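ExtractSubstructureContextPair also calls reset_idxes, which is not defined in this post (it ships with the downloadable code package). Here is a minimal sketch, assuming it only needs to renumber a subgraph's nodes as 0..n-1 and return the {old index: new index} map that __call__ uses as substruct_node_map and context_node_map. The packaged version may differ in details.

```python
import networkx as nx


def reset_idxes(G):
    """
    Relabel the nodes of G as 0..num_nodes-1 (in G's node iteration order).
    Returns a relabelled copy (node and edge attributes are kept) and the
    mapping {old_idx: new_idx}.
    """
    mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(G.nodes())}
    new_G = nx.relabel_nodes(G, mapping, copy=True)
    return new_G, mapping


# Usage: take the subgraph on nodes 2, 3, 4 of a 6-node path and renumber it
G = nx.path_graph(6)
nx.set_node_attributes(G, {n: {'atom_num_idx': 5, 'chirality_tag_idx': 0} for n in G})
sub = G.subgraph([2, 3, 4])
new_sub, mapping = reset_idxes(sub)
print(sorted(new_sub.nodes()))  # [0, 1, 2]
```

Building the mapping from G.nodes() and relabelling with the same mapping keeps the two consistent, so substruct_node_map[root_idx] always points at the root's row in the converted PyG subgraph.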
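Let's build the dataset with MoleculeDataset and pass ExtractSubstructureContextPair as the transform, so that the context graph and the other required attributes are extracted for each molecule. Because it is a transform rather than a pre_transform, it runs every time a molecule is fetched, so a new random center node is drawn each time: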
dataset = MoleculeDataset(root="dataset/zinc_standard_agent",
dataset='zinc_standard_agent',
transform = ExtractSubstructureContextPair(3, 2, 6))
dataset[0]
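The result is as follows:
The DataLoader also needs special handling. The K-hop subgraph, the context graph and, in particular, the center-node index are stored as extra graph-level attributes on each molecule. The default collation does not know how to offset these extra indices when it concatenates graphs, so after batching the indices of the center nodes and anchor nodes would no longer point to the right rows. The batching therefore has to be done by hand, mainly for the subgraphs, as in the BatchSubstructContext class below: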
# Batch loading
# Merge a list of molecular PyG graphs into one large disconnected graph for batch training
class BatchSubstructContext(Data):
    """
    Specialized batching for substructure context pair!
    """

    def __init__(self, batch=None, **kwargs):
        super(BatchSubstructContext, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        #keys = [set(data.keys) for data in data_list]
        #keys = list(set.union(*keys))
        #assert 'batch' not in keys

        batch = BatchSubstructContext()  # an empty instance of this class
        keys = ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct",
                "overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]

        for key in keys:
            batch[key] = []

        #batch.batch = []
        # used for pooling the context
        # For every anchor node in the batch, the index of the molecule it belongs to
        # (its length equals the total number of anchor nodes in the batch)
        batch.batch_overlapped_context = []
        # Number of anchor nodes of each molecule in the batch
        batch.overlapped_context_size = []

        cumsum_main = 0
        cumsum_substruct = 0
        cumsum_context = 0

        i = 0

        for data in data_list:
            # If there is no context graph, skip this molecule!
            if hasattr(data, "x_context"):
                num_nodes = data.num_nodes
                num_nodes_substruct = len(data.x_substruct)  # nodes in the K-hop subgraph
                num_nodes_context = len(data.x_context)  # nodes in the context graph

                #batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
                # Molecule index for each of this molecule's anchor nodes
                batch.batch_overlapped_context.append(torch.full((len(data.overlap_context_substruct_idx), ), i, dtype=torch.long))
                batch.overlapped_context_size.append(len(data.overlap_context_substruct_idx))

                # K-hop subgraphs: offset node indices by the number of subgraph nodes already in the batch
                for key in ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct"]:
                    item = data[key]
                    item = item + cumsum_substruct if batch.cumsum(key, item) else item
                    batch[key].append(item)

                ###batching for the context graph
                # Context graphs: offset node indices by the number of context nodes already in the batch
                for key in ["overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]:
                    item = data[key]
                    # Only index-type keys are offset; feature tensors are appended unchanged
                    item = item + cumsum_context if batch.cumsum(key, item) else item
                    batch[key].append(item)

                cumsum_main += num_nodes  # total nodes of the full graphs
                cumsum_substruct += num_nodes_substruct  # total nodes of the K-hop subgraphs
                cumsum_context += num_nodes_context  # total nodes of the context graphs
                i += 1  # number of molecules kept

        for key in keys:
            # Concatenate each attribute into one batch tensor
            batch[key] = torch.cat(
                batch[key], dim=batch.cat_dim(key))
        #batch.batch = torch.cat(batch.batch, dim=-1)
        batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1)
        batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size)

        return batch.contiguous()

    def cat_dim(self, key):
        # Edge indices have shape (2, num_edges) and are concatenated along the last dim;
        # node and edge features have shape (n, ...) and are concatenated along dim 0
        return -1 if key in ["edge_index", "edge_index_substruct", "edge_index_context"] else 0

    def cumsum(self, key, item):
        # True if this key holds node indices that must be offset when batching
        return key in ["edge_index", "edge_index_substruct", "edge_index_context", "overlap_context_substruct_idx", "center_substruct_idx"]

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1


# DataLoader for context pretraining batches
class DataLoaderSubstructContext(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super(DataLoaderSubstructContext, self).__init__(
            dataset,
            batch_size,
            shuffle,
            # a static method (unlike a lambda) can be pickled when num_workers > 0
            collate_fn=BatchSubstructContext.from_data_list,
            **kwargs)
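The index offsetting in BatchSubstructContext boils down to one rule: when graphs are concatenated, every node index of graph g is shifted by the number of nodes in the graphs before it. Here is a pure-Python sketch on toy dicts (batch_offsets and the dict layout are hypothetical). It also builds a per-node graph-id vector, similar in spirit to the commented-out batch.batch.

```python
def batch_offsets(graphs):
    """
    graphs: list of dicts with 'num_nodes', 'edges' [(src, dst), ...] and
    'center' (a local node index). Concatenate them into one disconnected graph
    by shifting every local index by the number of nodes already placed.
    """
    all_edges, centers, batch_vec = [], [], []
    offset = 0
    for g_id, g in enumerate(graphs):
        all_edges.extend((s + offset, d + offset) for s, d in g['edges'])
        centers.append(g['center'] + offset)        # like center_substruct_idx + cumsum_substruct
        batch_vec.extend([g_id] * g['num_nodes'])   # which graph each node belongs to
        offset += g['num_nodes']                    # like cumsum_substruct += num_nodes_substruct
    return all_edges, centers, batch_vec


graphs = [
    {'num_nodes': 3, 'edges': [(0, 1), (1, 0), (1, 2), (2, 1)], 'center': 1},
    {'num_nodes': 2, 'edges': [(0, 1), (1, 0)], 'center': 0},
]
edges, centers, batch_vec = batch_offsets(graphs)
print(edges)      # [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
print(centers)    # [1, 3]
print(batch_vec)  # [0, 0, 0, 1, 1]
```

Without the offset, the second graph's center (local index 0) would point at the first graph's node 0, which is exactly the mix-up the custom batching avoids.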
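2.3 GIN model
Perhaps because the paper is somewhat older, it uses a very simple GIN model with a hand-written GIN layer that adds bond-type and bond-direction embeddings to the messages. Similar layers are now available directly in torch_geometric.nn (GINConv, and GINEConv for edge features). The model outputs an embedding vector for every node in the graph.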
class GINConv(MessagePassing):
    """
    GIN layer from the paper (bond type and direction embeddings are added to the messages)
    """
    def __init__(self, emb_dim, num_bond_type=5, num_bond_direction=3, aggr="add"):
        super(GINConv, self).__init__(aggr="add")
        # multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]
        edge_index = edge_index.long()

        # add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


class GNN(torch.nn.Module):
    """
    GIN model: stacks num_layer GINConv layers and outputs node embeddings
    """
    def __init__(self, num_layer, emb_dim, num_atom_type=120, num_chirality_tag=4, JK="last", drop_ratio=0.5):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ### List of GIN layers
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINConv(emb_dim, aggr="add"))

        ### List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of JK (jumping knowledge)
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)

        return node_representation
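For reference, this is what one GINConv layer above computes, written with plain tensor indexing instead of MessagePassing. It is a sketch that assumes PyG's default source-to-target flow (messages travel from edge_index[0] to edge_index[1]) and that self-loops and edge embeddings have already been prepared, as GINConv.forward does. gin_aggregate is a hypothetical name.

```python
import torch


def gin_aggregate(x, edge_index, edge_emb, mlp):
    """
    One GIN-style node update: h_i' = mlp( sum over edges j -> i of (h_j + e_ji) ).
    x: [N, D] node embeddings
    edge_index: [2, E] long tensor, row 0 = source j, row 1 = target i
                (self-loops must already be included)
    edge_emb: [E, D] edge embeddings, one per column of edge_index
    """
    src, dst = edge_index
    messages = x[src] + edge_emb                                  # message(): x_j + edge_attr
    aggr_out = torch.zeros_like(x).index_add_(0, dst, messages)   # aggr="add" at each target node
    return mlp(aggr_out)                                          # update(): self.mlp(aggr_out)


# Toy graph: path 0-1-2 with both bond directions plus one self-loop per node
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                           [1, 0, 2, 1, 0, 1, 2]])
edge_emb = torch.zeros(edge_index.size(1), 4)  # zero edge embeddings isolate the neighbor sum
out = gin_aggregate(x, edge_index, edge_emb, torch.nn.Identity())
```

With zero edge embeddings and an identity MLP, node 1's output is simply x[0] + x[1] + x[2]. The self-loop is what lets a node keep its own features in the sum.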
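Following the paper, there are two models: one runs on the molecule's K-hop subgraphs and the other on the context graphs. Their definitions differ only in the number of layers: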
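device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Main model, trained on the K-hop subgraphs of the molecular graphs
model = GNN(7, 512)
model = model.to(device)
# Context model, trained on the context graphs
context_model = GNN(4, 512)
context_model = context_model.to(device)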
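2.4 Training
First, write a function train that runs one epoch.
Note: substruct_rep holds the center-node embeddings computed by the main model on the K-hop subgraphs, and overlapped_node_rep holds the context anchor node embeddings computed by the context model on the context graphs. We want these to be similar, so each (center, anchor) pair is a positive sample. shifted_expanded_substruct_rep holds the negatives: center-node embeddings taken from other molecules in the same batch (by cyclically shifting the batch), which should be dissimilar to the anchor embeddings. This way the network learns that nodes with similar surrounding substructures get similar embeddings. Because the two kinds of embeddings are compared directly, the two models need a similar structure, so the main model and the context model differ only in the number of layers.
As you will see, the loss function is torch.nn.BCEWithLogitsLoss(), i.e. binary cross-entropy applied to raw (logit) scores.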
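As a sanity check of what the loss and the reported accuracy measure, here is a dependency-free sketch on toy scores. It computes binary cross-entropy on logits for positives (label 1) and negatives (label 0), averages each group separately and sums them, then computes the balanced accuracy (positives should score above 0, negatives below 0). The helper names are hypothetical. Unlike PyTorch's implementation, this version is not numerically stable for very large |logit|.

```python
import math


def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw score; label is 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid
    return -math.log(p) if label == 1 else -math.log(1.0 - p)


def context_pred_loss_and_acc(pred_pos, pred_neg):
    """Mean positive loss + mean negative loss, and the balanced accuracy."""
    loss_pos = sum(bce_with_logits(s, 1) for s in pred_pos) / len(pred_pos)
    loss_neg = sum(bce_with_logits(s, 0) for s in pred_neg) / len(pred_neg)
    acc = 0.5 * (sum(s > 0 for s in pred_pos) / len(pred_pos)
                 + sum(s < 0 for s in pred_neg) / len(pred_neg))
    return loss_pos + loss_neg, acc


# Toy scores: one positive and one negative are on the wrong side of 0
loss, acc = context_pred_loss_and_acc(pred_pos=[2.0, -1.0], pred_neg=[-3.0, 0.5])
print(acc)  # 0.5
```

A score of exactly 0 gives a loss of log 2 for either label, which is roughly the per-group loss an untrained model starts from.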
# Train for one epoch
def train(model, context_model, loader, optimizer_substruct, optimizer_context, criterion, device):
    #
    # This follows the skip-gram idea: the center node is used to predict its surrounding (anchor) nodes.
    # In skip-gram every node is shaped by its neighbors; whenever a node serves as the center,
    # it is predicted and adjusted once for each surrounding node.
    # So when data is scarce, or a node is rare and appears only a few times,
    # these repeated adjustments make its embedding comparatively more accurate.
    #
    model.train()
    context_model.train()

    balanced_loss_accum = 0
    acc_accum = 0

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        # Embedding of the root (center) node of each K-hop subgraph
        substruct_rep = model(batch.x_substruct,
                              batch.edge_index_substruct,
                              batch.edge_attr_substruct)[batch.center_substruct_idx]

        # Embeddings of the context anchor nodes of each context graph
        overlapped_node_rep = context_model(batch.x_context,
                                            batch.edge_index_context,
                                            batch.edge_attr_context)[batch.overlap_context_substruct_idx]

        # Repeat each root embedding once per anchor node of the same molecule
        expanded_substruct_rep = torch.cat([substruct_rep[i].repeat((batch.overlapped_context_size[i], 1))
                                            for i in range(len(substruct_rep))], dim=0)
        # Positive scores: inner product of root and anchor embeddings (similarity)
        pred_pos = torch.sum(expanded_substruct_rep * overlapped_node_rep, dim=1)

        # Negative samples
        shifted_expanded_substruct_rep = []
        neg_samples = 1  # number of negatives per positive
        for i in range(neg_samples):
            # Cyclically shift the roots so that each molecule's anchors are paired with another molecule's root
            shifted_substruct_rep = substruct_rep[cycle_index(len(substruct_rep), i + 1)]
            # Repeat each shifted root once per anchor node of molecule j
            shifted_expanded_substruct_rep.append(torch.cat([shifted_substruct_rep[j].repeat(
                (batch.overlapped_context_size[j], 1)) for j in range(len(shifted_substruct_rep))], dim=0))

        # Stack all negatives into one tensor
        shifted_expanded_substruct_rep = torch.cat(shifted_expanded_substruct_rep, dim=0)
        # Negative scores: shifted roots vs. anchor embeddings
        pred_neg = torch.sum(shifted_expanded_substruct_rep * overlapped_node_rep.repeat((neg_samples, 1)), dim=1)

        # Positive / negative losses
        loss_pos = criterion(pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double())
        loss_neg = criterion(pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double())

        # Reset gradients
        optimizer_substruct.zero_grad()
        optimizer_context.zero_grad()

        # Total loss
        loss = loss_pos + neg_samples * loss_neg
        # Backpropagate
        loss.backward()
        # Update parameters
        optimizer_substruct.step()
        optimizer_context.step()

        # Accumulate loss
        balanced_loss_accum += float(loss_pos.detach().cpu().item() + loss_neg.detach().cpu().item())
        # Accumulate accuracy: positives should score > 0, negatives < 0
        acc_accum += 0.5 * (float(torch.sum(pred_pos > 0).detach().cpu().item()) / len(pred_pos)
                            + float(torch.sum(pred_neg < 0).detach().cpu().item()) / len(pred_neg))

    # Average loss and accuracy over the batches of this epoch (step is 0-based)
    num_batches = step + 1
    return balanced_loss_accum / num_batches, acc_accum / num_batches
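train() calls cycle_index, which is also not defined in this post. Here is a minimal sketch, assuming it returns the indices 0..num-1 cyclically shifted by shift, so that substruct_rep[cycle_index(B, 1)] pairs molecule i's anchors with the center node of molecule (i + 1) mod B. Under that assumption it is equivalent to torch.roll(torch.arange(num), -shift).

```python
import torch


def cycle_index(num, shift):
    """
    Return [shift, shift + 1, ..., num - 1, 0, 1, ..., shift - 1] as a long tensor,
    i.e. the indices 0..num-1 cyclically shifted left by `shift` (requires 1 <= shift <= num).
    """
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)  # wrap the last `shift` positions around to 0..shift-1
    return arr


# Usage: one toy "root embedding" per molecule; shifting by 1 pairs molecule i with molecule i + 1
reps = torch.tensor([[10.0], [20.0], [30.0]])
print(reps[cycle_index(3, 1)].squeeze(1).tolist())  # [20.0, 30.0, 10.0]
```

Shifting within the batch gives negatives for free, with no extra forward passes. With a batch of a single molecule, however, the "negative" is the molecule's own center.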
The actual training script:
if __name__ == '__main__':
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Main model
    model = GNN(7, 512)
    model = model.to(device)
    # Context model
    context_model = GNN(4, 512)
    context_model = context_model.to(device)
    # Load the data
    dataset = MoleculeDataset(
        root="dataset/zinc_standard_agent",
        dataset='zinc_standard_agent',
        transform=ExtractSubstructureContextPair(3, 2, 6))
    loader = DataLoaderSubstructContext(dataset, batch_size=512, shuffle=True, num_workers=8, pin_memory=True)
    # batch_size=1024 takes roughly 16 GB of memory
    # Optimizers
    optimizer_substruct = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    optimizer_context = optim.Adam(context_model.parameters(), lr=0.001, weight_decay=0.0001)
    # Loss function
    criterion = torch.nn.BCEWithLogitsLoss()
    # Train epoch by epoch
    # number of epochs
    epochs = 100
    # logs
    log_loss = []
    log_acc = []
    for epoch in range(epochs):
        print("====epoch " + str(epoch))
        train_loss, train_acc = train(model, context_model, loader, optimizer_substruct, optimizer_context, criterion, device)
        log_loss.append(train_loss)
        log_acc.append(train_acc)
        # Save loss and accuracy
        np.save("Context_Pretrain__loss.npy", log_loss)
        np.save("Context_Pretrain_log_acc.npy", log_acc)
        print('Epoch:{},loss:{}, acc:{}'.format(
            epoch, train_loss, train_acc))
        # Save the models after every epoch, since training takes a long time
        torch.save(model.state_dict(), "Context_Pretrain_GIN_para.pth")
        torch.save(model, "Context_Pretrain_GIN.pth")
        torch.save(context_model, 'context_model.pth')
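3. Training results
After training we obtain a pretrained GIN model, saved in Context_Pretrain_GIN.pth.
Training takes a long time: on a V100 GPU each iteration takes about 15 minutes, and 100 iterations take roughly three days. The training loss curve is shown below:
4. Runtime environment
The working directory looks like this: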
.
├── Context_Pretrain_GIN.pth
├── Context_Pretrain_GIN_para.pth
├── Context_Pretrain__loss.npy
├── Context_Pretrain_log_acc.npy
├── Pyg_pretrain.yml  # runtime environment
├── context_model.pth
├── context_pretrain.py
├── dataset  # dataset
├── pretrain_context_预训练损失函数.ipynb
└── 说明.text
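See Pyg_pretrain.yml for the runtime environment. The code above is missing some required helper functions; to run it directly, download the code package.
Download link for the code and data:
Link: https://pan.baidu.com/s/14cxHjU2zwzkqPfwwfuSx0Q
Extraction code: 795y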