Pre-training Graph Neural Networks (2): Context Prediction (Substructure Prediction) Code

As mentioned in the previous article, the authors of Strategies for Pre-training Graph Neural Networks proposed two node-level pre-training methods: Context Prediction and Attribute Prediction. Both teach the model node-level embeddings. When the pre-trained model is then applied to downstream graph-level tasks, node-level information is well preserved and the model generalizes better.

This part introduces substructure prediction, i.e., Context Prediction, in detail and provides a directly runnable version of the code.

1. Introduction to Context Prediction pre-training

The Context pre-training scheme is illustrated in the figure below:

First, the molecule's SMILES string is converted into a graph, and one node is chosen at random as the center node V.

For any center node V we can extract a K-hop subgraph. After K rounds of message passing in a GNN, every node in this K-hop subgraph has been influenced by the center node V.

The nodes that lie between the r1-hop and the r2-hop subgraphs (with r1 < r2), i.e., the nodes in the r2-hop subgraph but not in the r1-hop subgraph, form the context graph. It is the substructure at a distance between r1 and r2 hops from the center node V: the region between the two pink dashed circles in the figure above.

As for the relationship among r1, r2, and K, the paper requires r1 < K < r2. The nodes between r1 and K hops are called context anchor nodes (the nodes between the small pink circle and the blue circle); they describe how information about the center node V propagates outward toward the nodes beyond r2.

The context graph and the context anchor nodes are the two key concepts here.

The goal of Context Prediction pre-training is to judge whether a node in the molecular graph is a context anchor node of a given center node.

Concretely, there are two GNN models, model1 and model2. model1 is the main model: it performs message passing on the molecular graph and outputs an embedding vector for every node. model2 performs message passing only on the context graph and produces embeddings for the context-graph nodes. From model1 we take the center node's embedding X(root); from model2 we take the context anchor nodes' embeddings X(anchor). We also negatively sample some nodes as pseudo center nodes and take their model1 embeddings X(root-false).

The self-supervised objective: the center node embedding X(root) should be maximally similar to the anchor embeddings X(anchor), while the pseudo center node embeddings X(root-false) should be dissimilar to the anchors. In other words, nodes within the same structural neighborhood, and hence with similar chemical environments, receive similar embeddings. This is how model1 learns to recognize molecular substructures.

model1 is the model we are pre-training.

In summary, the procedure has the following steps:

  1. Build the molecule's full graph;
  2. Using r1, r2, and K, extract the molecule's K-hop subgraph and context graph, and mark the context anchor nodes;
  3. Negatively sample pseudo center nodes at a 1:1 ratio;
  4. During training, judge whether the center node's and the pseudo center nodes' embeddings are similar to the context anchor nodes' embeddings, using negative log-likelihood as the loss.

The aim of self-supervised training: structurally close nodes get similar embeddings, structurally distant nodes get dissimilar embeddings. The toy sketch below makes the neighborhood definitions concrete.
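To see how the K-hop subgraph, context graph, and anchor nodes relate, here is a toy sketch on a hypothetical 8-node path graph (not from the original post), using the same networkx shortest-path call the real code uses later:

import networkx as nx

# toy "molecule": the path graph 0-1-2-3-4-5-6-7 (hypothetical example)
G = nx.path_graph(8)
root, K, r1, r2 = 0, 3, 2, 6  # same roles as K, r1, r2 in the paper

k_hop  = set(nx.single_source_shortest_path_length(G, root, K))   # {0, 1, 2, 3}
r1_hop = set(nx.single_source_shortest_path_length(G, root, r1))  # {0, 1, 2}
r2_hop = set(nx.single_source_shortest_path_length(G, root, r2))  # {0, 1, ..., 6}

context = r2_hop - r1_hop  # context graph nodes: {3, 4, 5, 6}
anchors = context & k_hop  # context anchor nodes: {3}
print(context, anchors)

With r1 < K < r2, the anchor set is exactly the overlap between the K-hop subgraph and the context graph, which is what ExtractSubstructureContextPair computes below.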

2. Code walkthrough

The following subsections walk through the code.

The environment for running this code is listed in the Pyg_pretrain.yml file. Note that I use a recent torch_geometric release, 2.0.3, while the paper used 1.0.3, which can no longer be downloaded and installed. See Section 4 for the environment.

2.1 Importing the libraries

We need pandas, networkx, torch, and torch_geometric, plus RDKit, which is required for parsing molecules.

import pandas as pd
from tqdm import tqdm
import numpy as np
import networkx as nx

import os
import math
import random
import torch
import torch.optim as optim

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect

import torch.nn.functional as F
from torch_geometric.data import Data  # PyG's graph data type
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax

# fix the random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

2.2 Data preprocessing and loading

The allowed atom and bond feature values are as follows:

# allowable node and edge features
# (in practice only two atom features and two bond features are used)
allowable_features = {
    'possible_atomic_num_list' : list(range(1, 119)), # atomic number
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], # formal charge
    'possible_chirality_list' : [ # chirality type
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [ # hybridization type
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8], # number of attached Hs
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6], # implicit valence
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # degree
    'possible_bonds' : [  # bond type
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [ # bond direction, only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}

Although many node (atom) and edge (bond) feature values are allowed above, the paper uses only the atomic number and the chirality tag as node features, ignoring degree, implicit valence, and so on. Because the atomic number is embedded directly, the embedding table needs an entry for each of the 119 possible elements, so the paper's way of constructing the graph may not be ideal.

The RDKit mol object is converted into PyG's data type: node features are stored in data.x, edge features in data.edge_attr, and edges in data.edge_index. There are num_atom_features = 2 node features (atom type and chirality) and likewise num_bond_features = 2 edge features (bond type and bond direction).

def mol_to_graph_data_obj_simple(mol):
    """
    Convert an RDKit mol object into a PyG Data object.
    """
    # nodes: each atom gets only its atom type and chirality tag
    num_atom_features = 2   # atom type, chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # edges: each bond gets only its bond type and bond direction
    num_bond_features = 2   # bond type, bond direction
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                                            'possible_bond_dirs'].index(
                bond.GetBondDir())]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:   # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) # Data is PyG's graph type
    return data
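As a quick sanity check, converting a small molecule such as ethanol (SMILES 'CCO', an arbitrary example not used in the post) should give 3 atoms and 2 bonds, with each bond stored in both directions:

# hypothetical sanity check with an arbitrary SMILES string
mol = Chem.MolFromSmiles('CCO')
data = mol_to_graph_data_obj_simple(mol)
print(data)  # expected: Data(x=[3, 2], edge_index=[2, 4], edge_attr=[4, 2])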

Based on PyG's InMemoryDataset, the MoleculeDataset class converts molecules from SMILES into PyG graphs and stores them as a PyG dataset. By default it loads the zinc250k dataset from the dataset folder. It applies the mol_to_graph_data_obj_simple function above to each molecule and then merges everything into a dataset via self.collate(data_list).

class MoleculeDataset(InMemoryDataset):
    '''
    Load the zinc dataset as a PyG Dataset.
    '''
    def __init__(self, root, dataset='zinc250k',
                 transform=None, pre_transform=None, 
                 pre_filter=None):
        
        print(dataset)
        self.dataset = dataset
        
        self.root = root
        super(MoleculeDataset, self).__init__(root, transform, pre_transform,
                                                 pre_filter) # must come after the attributes above
        
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property # list of raw files
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        return file_name_list
    
    @property # processed file name; if it already exists, processing is skipped
    def processed_file_names(self):
        return 'geometric_data_processed.pt'
    
    def process(self):
        data_smiles_list = []
        data_list = []

        input_path = self.raw_paths[0]
        input_df = pd.read_csv(input_path, sep=',', compression='gzip',
                                   dtype='str')
        smiles_list = list(input_df['smiles'])
        zinc_id_list = list(input_df['zinc_id'])
        for i in range(len(smiles_list)):
            if i % 1000 == 0:
                print(str(i) + '...')
            s = smiles_list[i]
            try:
                rdkit_mol = AllChem.MolFromSmiles(s, sanitize=True)
                if rdkit_mol is not None:  # ignore invalid mol objects
                    # # convert aromatic bonds to double bonds
                    # Chem.SanitizeMol(rdkit_mol,sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)
                    data = mol_to_graph_data_obj_simple(rdkit_mol)
                    # manually add mol id: the zinc id value, stripped of leading zeros
                    id = int(zinc_id_list[i].split('ZINC')[1].lstrip('0'))
                    data.id = torch.tensor([id])
                    data_list.append(data)
                    data_smiles_list.append(smiles_list[i])
            except:
                continue  # skip molecules RDKit fails to parse
                
        # filter
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        # transform
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # write data_smiles_list in processed paths
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir,
                                               'content_smiles.csv'), index=False,
                                  header=False)
        
        # InMemoryDataset method: collate a list of torch_geometric.data.Data
        # objects into the internal storage, saved to processed_paths[0]
        data, slices = self.collate(data_list) 
        torch.save((data, slices), self.processed_paths[0])
        
    def __repr__(self):
        return '{}()'.format(self.dataset)

As described above, we need to randomly extract a context graph from each molecular graph, i.e., add context-related attributes to a molecule's PyG data object: the K-hop subgraph with its indices; the context graph with its indices; and the context anchor nodes with their indices within the context graph.

Processing pipeline:

molecule → mol object → PyG graph → NX graph → K-hop subgraph / context graph → context anchor nodes → PyG graph attributes

This procedure is wrapped in the ExtractSubstructureContextPair class, which adds the following attributes to each molecular graph:

data.center_substruct_idx: index of the center node
data.x_substruct: node features of the K-hop subgraph
data.edge_attr_substruct: edge features of the K-hop subgraph
data.edge_index_substruct: edges of the K-hop subgraph
data.x_context: node features of the context graph
data.edge_attr_context: edge features of the context graph
data.edge_index_context: edges of the context graph
data.overlap_context_substruct_idx: indices of the context anchor nodes

class ExtractSubstructureContextPair:
    def __init__(self, k, l1, l2):
        """
        Randomly pick a root node, then mark the K-hop subgraph and the
        context graph according to k, r1 (l1) and r2 (l2).
        """
        self.k = k   # K in the paper
        self.l1 = l1 # r1 in the paper
        self.l2 = l2 # r2 in the paper

        # for the special case of 0, mark it as negative; addresses the
        # quirk with single_source_shortest_path_length
        if self.k == 0:
            self.k = -1
        if self.l1 == 0:
            self.l1 = -1
        if self.l2 == 0:
            self.l2 = -1

    def __call__(self, data, root_idx=None):
        """
        Adds the following attributes to the PyG data object:
        data.center_substruct_idx
        data.x_substruct
        data.edge_attr_substruct
        data.edge_index_substruct
        data.x_context
        data.edge_attr_context
        data.edge_index_context
        data.overlap_context_substruct_idx
        """
        num_atoms = data.x.size()[0] # number of nodes
        # pick a random root node if none is given
        if root_idx is None:
            root_idx = random.sample(range(num_atoms), 1)[0]

        # convert the PyG data object to an NX graph
        G = graph_data_obj_to_nx_simple(data)  # same ordering as input data obj

        # node indices of the root node's K-hop subgraph
        substruct_node_idxes = nx.single_source_shortest_path_length(G,
                                                                     root_idx,
                                                                     self.k).keys()
        # if the K-hop subgraph is non-empty
        if len(substruct_node_idxes) > 0:
            substruct_G = G.subgraph(substruct_node_idxes) # K-hop subgraph
            substruct_G, substruct_node_map = reset_idxes(substruct_G)  # reset its node indices
            
            substruct_data = nx_to_graph_data_obj_simple(substruct_G) # NX graph -> PyG graph
            # attach the K-hop subgraph to the molecule's PyG data object
            data.x_substruct = substruct_data.x
            data.edge_attr_substruct = substruct_data.edge_attr
            data.edge_index_substruct = substruct_data.edge_index
            # remember the center node's index within the subgraph
            data.center_substruct_idx = torch.tensor([substruct_node_map[
                                                          root_idx]])  
        # r1-hop subgraph
        l1_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l1).keys()
        # r2-hop subgraph
        l2_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l2).keys()
        # nodes between r1 and r2 hops (symmetric difference)
        context_node_idxes = set(l1_node_idxes).symmetric_difference(
            set(l2_node_idxes))
        # if the context graph is non-empty
        if len(context_node_idxes) > 0:
            context_G = G.subgraph(context_node_idxes) # extract the context graph
            context_G, context_node_map = reset_idxes(context_G) # reset its node indices
            context_data = nx_to_graph_data_obj_simple(context_G) # NX graph -> PyG graph
            # attach the context graph to the molecule's data object
            data.x_context = context_data.x
            data.edge_attr_context = context_data.edge_attr
            data.edge_index_context = context_data.edge_index

        # context anchor nodes: nodes shared by the context graph and the K-hop subgraph
        context_substruct_overlap_idxes = list(set(
            context_node_idxes).intersection(set(substruct_node_idxes)))
        # if there are context anchor nodes
        if len(context_substruct_overlap_idxes) > 0:
            # record the anchor nodes' indices within the context graph;
            # need to convert the overlap node idxes, which is from the
            # original graph node ordering, to the new context node ordering
            context_substruct_overlap_idxes_reorder = [context_node_map[old_idx]
                                                       for
                                                       old_idx in
                                                       context_substruct_overlap_idxes]
            data.overlap_context_substruct_idx = \
                torch.tensor(context_substruct_overlap_idxes_reorder)

        return data

    def __repr__(self):
        return '{}(k={},l1={}, l2={})'.format(self.__class__.__name__, self.k,
                                              self.l1, self.l2)

ExtractSubstructureContextPair converts the PyG graph into an NX graph and, after processing, converts it back into a PyG graph, so the following support functions are needed (plus reset_idxes, supplied after them):

# convert an NX graph into a PyG data object
def nx_to_graph_data_obj_simple(G):
    """
    Convert an NX graph into a PyG Data object.
    """
    # atoms (nodes)
    num_atom_features = 2  # atom type, chirality tag
    atom_features_list = []
    for _, node in G.nodes(data=True):
        atom_feature = [node['atom_num_idx'], node['chirality_tag_idx']]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds (edges)
    num_bond_features = 2  # bond type, bond direction
    # if the graph has edges
    if len(G.edges()) > 0:
        edges_list = []
        edge_features_list = []
        for i, j, edge in G.edges(data=True):
            edge_feature = [edge['bond_type_idx'], edge['bond_dir_idx']]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:   
        # no edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data

# convert a PyG data object into an NX graph
# (needed for extracting subgraphs)
def graph_data_obj_to_nx_simple(data):
    # data is a PyG graph
    G = nx.Graph()

    # atoms
    atom_features = data.x.cpu().numpy() # node features
    num_atoms = atom_features.shape[0]   # number of nodes
    # add the node features to the NX graph
    for i in range(num_atoms):
        atomic_num_idx, chirality_tag_idx = atom_features[i]
        G.add_node(i, atom_num_idx=atomic_num_idx, chirality_tag_idx=chirality_tag_idx)

    # add the edge features to the NX graph
    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    num_bonds = edge_index.shape[1]
    for j in range(0, num_bonds, 2):  # edges are stored in both directions; take every other one
        begin_idx = int(edge_index[0, j])
        end_idx = int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j] # edge features
        if not G.has_edge(begin_idx, end_idx):
            # add the edges one by one
            G.add_edge(begin_idx, end_idx, bond_type_idx=bond_type_idx,
                       bond_dir_idx=bond_dir_idx)
    return G
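ExtractSubstructureContextPair also calls a reset_idxes helper that is not shown in this post. A minimal version, consistent with the helper in the original pretrain-gnns repository, relabels a subgraph's nodes as 0 .. num_nodes - 1 and returns the old-to-new index mapping:

def reset_idxes(G):
    """
    Relabel the nodes of G as 0, 1, ..., num_nodes - 1 and return the
    relabelled copy together with the {old_idx: new_idx} mapping.
    """
    mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(G.nodes())}
    new_G = nx.relabel_nodes(G, mapping, copy=True)
    return new_G, mapping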

Let's build the dataset with MoleculeDataset and pass ExtractSubstructureContextPair as the transform, so that each molecule's context graph and the associated data are extracted:

dataset = MoleculeDataset(root="dataset/zinc_standard_agent",
    dataset='zinc_standard_agent', 
    transform = ExtractSubstructureContextPair(3, 2, 6))
dataset[0]

The printed Data object now shows the extra substruct/context attributes (x_substruct, edge_index_substruct, center_substruct_idx, x_context, overlap_context_substruct_idx, and so on) alongside the original x, edge_index, and edge_attr.

The DataLoader also needs special handling. The K-hop subgraph and context-graph contents, in particular the center node indices, are stored as graph-level attributes. PyTorch's default DataLoader does not treat them specially: the graphs would simply be concatenated, the context graphs would lose their original indices, and we would no longer know which nodes are the anchor and center nodes. So batching must be handled explicitly, mainly for the subgraphs, as in the BatchSubstructContext class below:

# batch loading of data:
# merge a list of PyG graphs into one large disconnected PyG graph
# for mini-batch training
class BatchSubstructContext(Data):

    """
    Specialized batching for substructure context pairs!
    """

    def __init__(self, batch=None, **kwargs):
        super(BatchSubstructContext, self).__init__(**kwargs)
        self.batch = batch 

    @staticmethod
    def from_data_list(data_list):
        batch = BatchSubstructContext() # the batch is itself a Data object
        keys = ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct", "overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]

        for key in keys:
            batch[key] = []

        # used for pooling the context:
        # which graph each anchor node belongs to (one entry per anchor node in the batch)
        batch.batch_overlapped_context = []
        # number of anchor nodes per graph
        batch.overlapped_context_size = []

        cumsum_main = 0
        cumsum_substruct = 0
        cumsum_context = 0

        i = 0
        
        for data in data_list:
            # if there is no context graph, just skip!
            if hasattr(data, "x_context"):
                num_nodes = data.num_nodes
                num_nodes_substruct = len(data.x_substruct) # nodes in the K-hop subgraph
                num_nodes_context = len(data.x_context)     # nodes in the context graph

                # record which graph the context anchor nodes belong to
                batch.batch_overlapped_context.append(torch.full((len(data.overlap_context_substruct_idx), ), i, dtype=torch.long))
                batch.overlapped_context_size.append(len(data.overlap_context_substruct_idx))

                # batching for the K-hop subgraphs
                for key in ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct"]:
                    item = data[key]
                    # index-valued keys are offset by the running node count
                    item = item + cumsum_substruct if batch.cumsum(key, item) else item
                    batch[key].append(item)

                # batching for the context graphs
                for key in ["overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]:
                    item = data[key]
                    item = item + cumsum_context if batch.cumsum(key, item) else item
                    batch[key].append(item)

                cumsum_main += num_nodes                 # total node count
                cumsum_substruct += num_nodes_substruct  # total K-hop subgraph node count
                cumsum_context += num_nodes_context      # total context graph node count
                i += 1 # number of molecules

        for key in keys:
            # stack each attribute into a single batch tensor
            batch[key] = torch.cat(batch[key], dim=batch.cat_dim(key))
        batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1)
        batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size)

        return batch.contiguous()

    def cat_dim(self, key):
        # note: edge indices are stacked along a different dimension than
        # features, since edge_index is (2, n) while features are (n, ...)
        return -1 if key in ["edge_index", "edge_index_substruct", "edge_index_context"] else 0

    def cumsum(self, key, item):
        # whether the values of this key are node indices that must be
        # offset by the cumulative node count when batching
        return key in ["edge_index", "edge_index_substruct", "edge_index_context", "overlap_context_substruct_idx", "center_substruct_idx"]

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1

# DataLoader for the context pre-training batches
class DataLoaderSubstructContext(torch.utils.data.DataLoader):

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super(DataLoaderSubstructContext, self).__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list),
            **kwargs)

2.3 The GIN model

The paper is somewhat dated: it uses a very simple GIN model with a hand-written GIN layer, whereas nowadays a GIN can be imported directly from torch_geometric.nn.models (see the sketch below). The model outputs an embedding vector for every node in the graph.
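For comparison, a rough sketch of the modern alternative (an assumption on my part, for PyG ≥ 2.0): torch_geometric.nn.models.GIN builds a stacked GIN directly. Note that it does not embed bond features the way the custom GINConv below does, so it is only an approximate substitute, not the paper's model:

from torch_geometric.nn.models import GIN as PygGIN

# a 5-layer GIN over 512-dimensional, already-embedded node features;
# the layer count and dimensions here are illustrative only
pyg_gin = PygGIN(in_channels=512, hidden_channels=512, num_layers=5)
# node_emb = pyg_gin(x, edge_index)  # x: float node features; edge_attr is ignored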

class GINConv(MessagePassing):
    """
    The GIN layer from the paper, extended to embed edge (bond) features.
    """
    def __init__(self, emb_dim, num_bond_type=5, num_bond_direction=3, aggr = "add"):
        super(GINConv, self).__init__(aggr = aggr)  # pass the chosen aggregation through
        # multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))[0]
        edge_index = edge_index.long()
        
        # add features corresponding to self-loop edges
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 # bond type for self-loop edges
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])        
        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

class GNN(torch.nn.Module):
    """
    Multi-layer GIN that outputs node embeddings.
    """
    def __init__(self, num_layer, emb_dim, num_atom_type=120, num_chirality_tag=4, JK = "last", drop_ratio = 0.5):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINConv(emb_dim, aggr = "add"))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

Following the paper, there are two models: the main one trained on the molecular graph (its K-hop subgraphs) and another trained on the context graphs. Their definitions differ only in the number of layers (the context model's 4 layers match r2 - r1 = 6 - 2):

# main model, trained on the molecular graph
model = GNN(7,512)
model = model.to(device)  # device is defined in the __main__ block below

# context model, trained on the context graph
context_model = GNN(4,512)
context_model = context_model.to(device)

2.4 The training loop

First we write a single-epoch training function, train.

Note: substruct_rep holds the embeddings of the center nodes taken from the molecular (K-hop) graphs, and overlapped_node_rep holds the embeddings of the context anchor nodes taken from the context graphs. We want these two to be similar, so they form the positive pairs. The negatives, shifted_expanded_substruct_rep, are embeddings of center nodes drawn from other molecules in the batch via a cyclic shift; these act as pseudo center nodes whose embeddings should be dissimilar to the anchors. In this way the network learns to give similar embeddings to nodes within similar substructure neighborhoods. Since both sets of embeddings must live in the same space, the main model and the context model share the same architecture and differ only in depth.

As you will see below, the loss function is torch.nn.BCEWithLogitsLoss(), binary cross-entropy applied to the dot-product logits.
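One caveat: train below calls a cycle_index helper that does not appear anywhere in this post. It returns a cyclically shifted index tensor, so that each context is paired with the center node of a different molecule in the batch as its negative sample. A minimal implementation, consistent with the utility function in the original pretrain-gnns repository:

def cycle_index(num, shift):
    # indices [shift, shift+1, ..., num-1, 0, 1, ..., shift-1]
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr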

# train for a single epoch
def train(model, context_model, loader, optimizer_substruct, optimizer_context, criterion, device):
    # This follows the skip-gram idea: the center node predicts its
    # surrounding (anchor) nodes. As in skip-gram, every node is influenced
    # by its neighbours and is predicted and adjusted many times when acting
    # as a center node, so even for rare nodes or small datasets the
    # embeddings become comparatively accurate.
    model.train()
    context_model.train()

    balanced_loss_accum = 0
    acc_accum = 0

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        # representations of the root (center) nodes of the K-hop subgraphs
        substruct_rep = model(batch.x_substruct, 
                        batch.edge_index_substruct, 
                        batch.edge_attr_substruct)[batch.center_substruct_idx]
        # representations of the context anchor nodes of the context graphs
        overlapped_node_rep = context_model(batch.x_context, 
                                batch.edge_index_context, 
                                batch.edge_attr_context)[batch.overlap_context_substruct_idx] 
        # repeat each root representation once per anchor node of its molecule
        expanded_substruct_rep = torch.cat([substruct_rep[i].repeat((batch.overlapped_context_size[i],1)) 
                                    for i in range(len(substruct_rep))], dim = 0)
        # dot product of root and anchor representations: positive logits
        pred_pos = torch.sum(expanded_substruct_rep * overlapped_node_rep, dim = 1)
        # negative-sample representations
        shifted_expanded_substruct_rep = []        
        neg_samples = 1 # negative sampling ratio
        for i in range(neg_samples):
            # cyclically shift the root representations so that each anchor is
            # paired with the center node of a different molecule (negative sample)
            shifted_substruct_rep = substruct_rep[cycle_index(len(substruct_rep), i+1)]
            shifted_expanded_substruct_rep.append(torch.cat([shifted_substruct_rep[i].repeat(
                    (batch.overlapped_context_size[i],1)) for i in range(len(shifted_substruct_rep))], dim = 0))
        # all negative samples in one tensor
        shifted_expanded_substruct_rep = torch.cat(shifted_expanded_substruct_rep, dim = 0)
        # dot product of pseudo roots and anchors: negative logits
        pred_neg = torch.sum(shifted_expanded_substruct_rep * overlapped_node_rep.repeat((neg_samples, 1)), dim = 1)
        # positive / negative losses
        loss_pos = criterion(pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double())
        loss_neg = criterion(pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double())
        # zero the gradients
        optimizer_substruct.zero_grad()
        optimizer_context.zero_grad()
        # combine the losses
        loss = loss_pos + neg_samples*loss_neg
        # backpropagate
        loss.backward()
        # update the parameters
        optimizer_substruct.step()
        optimizer_context.step()
        # accumulate the loss
        balanced_loss_accum += float(loss_pos.detach().cpu().item() + loss_neg.detach().cpu().item())
        # accumulate the accuracy
        acc_accum += 0.5* (float(torch.sum(pred_pos > 0).detach().cpu().item())/len(pred_pos) 
                            + float(torch.sum(pred_neg < 0).detach().cpu().item())/len(pred_neg))

    # return the mean loss and accuracy over the epoch
    # (enumerate starts at 0, so the number of batches is step + 1)
    return balanced_loss_accum/(step + 1), acc_accum/(step + 1)

The actual training driver:


if __name__ =='__main__':
    # device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # main model
    model = GNN(7,512)
    model = model.to(device)
    # context model
    context_model = GNN(4,512)
    context_model = context_model.to(device)

    # load the data
    dataset = MoleculeDataset(
        root="dataset/zinc_standard_agent",
        dataset='zinc_standard_agent',
        transform = ExtractSubstructureContextPair(3, 2, 6))
    loader = DataLoaderSubstructContext(dataset, batch_size=512, shuffle=True, num_workers=8, pin_memory=True)
    # batch_size=1024 needs roughly 16 GB of memory

    # optimizers
    optimizer_substruct = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    optimizer_context = optim.Adam(context_model.parameters(), lr=0.001, weight_decay=0.0001)

    # loss function
    criterion = torch.nn.BCEWithLogitsLoss()

    # train epoch by epoch
    epochs = 100
    # logs
    log_loss = []
    log_acc = []

    model = model.to(device)
    context_model = context_model.to(device)

    for epoch in range(epochs):
        print("====epoch " + str(epoch))        
        train_loss, train_acc = train(model, context_model, loader, optimizer_substruct, optimizer_context, criterion, device)

        log_loss.append(train_loss)
        log_acc.append(train_acc)
        # save the loss and accuracy logs
        np.save("Context_Pretrain__loss.npy", log_loss)
        np.save("Context_Pretrain_log_acc.npy", log_acc)

        print('Epoch:{},loss:{}, acc:{}'.format(
            epoch, train_loss, train_acc))
        # save the models; training takes very long, so save after every epoch
        torch.save(model.state_dict(), "Context_Pretrain_GIN_para.pth")
        torch.save(model, "Context_Pretrain_GIN.pth")
        torch.save(context_model, 'context_model.pth')
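To reuse the pre-trained weights later, e.g., for a downstream graph-level task, a minimal sketch (assuming the GNN class definition above is available):

# reload the pre-trained main model from the saved state dict
model = GNN(7, 512)
model.load_state_dict(torch.load("Context_Pretrain_GIN_para.pth", map_location='cpu'))
model.eval()  # switch off dropout before inference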

3. Training results

After training we obtain a pre-trained GIN model, saved in Context_Pretrain_GIN.pth.

Training takes quite a long time: on a V100 GPU each epoch takes about 15 minutes, and 100 epochs take roughly three days. The training loss curve is shown below:

4. Runtime environment

The working directory looks like this:

.
├── Context_Pretrain_GIN.pth
├── Context_Pretrain_GIN_para.pth
├── Context_Pretrain__loss.npy
├── Context_Pretrain_log_acc.npy
├── Pyg_pretrain.yml # runtime environment
├── context_model.pth
├── context_pretrain.py
├── dataset # dataset
├── pretrain_context_预训练损失函数.ipynb
└── 说明.text

See the Pyg_pretrain.yml file for the runtime environment. The code above omits some necessary support functions; to run it directly, download the full code package.

Download link for the code and data:

Link: https://pan.baidu.com/s/14cxHjU2zwzkqPfwwfuSx0Q
Extraction code: 795y
