PaGE-Link: Path-based Graph Neural Network Explanation for Heterogeneous Link Prediction (link prediction, edge prediction, and explanation)

"PaGE-Link: Path-based Graph Neural Network Explanation for Heterogeneous Link Prediction"

1. Paper Overview

The authors propose PaGE-Link, a path-based GNN explanation method for link (edge) prediction on heterogeneous graphs. Put simply, it explains why a predicted edge should exist, so that people can trust the prediction (neural networks are black boxes, and their outputs are often distrusted precisely because they lack interpretability).

The method generates explanations for predicted links, is model-scalable, and handles heterogeneous graphs. PaGE-Link produces explanations as paths connecting the node pair: it first predicts the link between the two nodes, then generates the corresponding explanation.

My own take: it filters out the important edges whose features support the connection, and produces a set of indirect paths as the evidence that the two nodes should be linked.


The figure in the paper gives an example: for the predicted link between user1 and item1, supporting path 1 is u1 -> i2 -> a1 -> i1 and supporting path 2 is u1 -> i3 -> u2 -> i1. Many other supporting paths exist; the model has simply learned to return the two most plausible ones, and the number of returned paths is configurable.
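
Concretely, each explanation is a sequence of typed edges. A minimal sketch (my own illustration; the edge-type names are made up) of what the two paths above could look like in the repo's output format of (canonical edge type, source node id, target node id) triples:

path1 = [(('user', 'clicks', 'item'), 1, 2),        # u1 -> i2
         (('item', 'written_by', 'author'), 2, 1),  # i2 -> a1
         (('author', 'writes', 'item'), 1, 1)]      # a1 -> i1
path2 = [(('user', 'clicks', 'item'), 1, 3),        # u1 -> i3
         (('item', 'clicked_by', 'user'), 3, 2),    # i3 -> u2
         (('user', 'clicks', 'item'), 2, 1)]        # u2 -> i1
paths = [path1, path2]  # `num_paths` controls how many paths are returned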


Given the importance and the challenges of GNN explanation for LP, the authors formulate it as a post-hoc, instance-level explanation problem, and produce explanations in the form of important paths connecting the source node and the target node.


As for related work, readers can look into the other studies in this direction on their own.


The proposed method has two main parts: k-core pruning, which shrinks the search space and lowers the time complexity (as sketched below), and mask learning, which identifies high-weight edges to serve as the evidence for path selection.
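
To see why the pruning is safe for path explanations, note that a degree-1 "dangling" node can never lie on a path between two distinct endpoints, so extracting the k-core discards it without losing any candidate path. A minimal sketch of the idea with networkx (the repo implements an edge-removal variant on DGL graphs in explainer.py that keeps node ids intact):

import networkx as nx

# Toy graph: a triangle 0-1-2 plus node 3 dangling off node 2
G = nx.Graph([(0, 1), (1, 2), (2, 0), (2, 3)])
core = nx.k_core(G, k=2)
print(sorted(core.edges()))  # [(0, 1), (0, 2), (1, 2)]; the dangling edge (2, 3) is gone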


The ultimate goal of model explanation is to improve model transparency and support human decision-making, so human evaluation is the best way to assess an explainer's effectiveness and has been the standard evaluation method in prior work. The authors performed a human evaluation on 100 predicted links randomly selected from the AugCitation test set, generating explanations for each link with GNNExp-Link, PGExp-Link, and PaGE-Link. They designed a survey of single-choice questions: each question shows respondents the predicted link and the three explanations, together with the graph structure and node/edge type information (similar to Figure 5 in the paper, but without the method names). The survey was sent to graduate students, postdocs, engineers, research scientists, and professors, both with and without background knowledge of GNNs. Respondents were asked to "please choose the best explanation for 'why does the model predict this author will like the recommended paper?'". At least three answers from different people were collected per question. In total, 340 ratings were collected, and 78.79% of them rated PaGE-Link's explanation as the best.

2. Open-Source Code

data_processing.py

Data processing pipeline.

import dgl
import torch
import scipy.sparse as sp
import numpy as np
from utils import eids_split, remove_all_edges_of_etype, get_num_nodes_dict

def process_data(g, 
                 val_ratio, 
                 test_ratio,
                 src_ntype = 'author', 
                 tgt_ntype = 'paper',
                 pred_etype = 'likes',
                 neg='pred_etype_neg'):
    '''
    Parameters
    ----------
    g : dgl graph
    
    val_ratio : float
    
    test_ratio : float
    
    src_ntype: string
        source node type
    
    tgt_ntype: string
        target node type

    neg: string
        One of ['pred_etype_neg', 'src_tgt_neg'], different negative sampling modes. See below.
    
    Returns
    ----------
    mp_g: 
        graph for message passing.
    
    graphs containing positive edges and negative edges for train, valid, and test
    '''
    
    u, v = g.edges(etype=pred_etype)
    src_N = g.num_nodes(src_ntype)
    tgt_N = g.num_nodes(tgt_ntype)

    M = u.shape[0] # number of directed edges
    eids = torch.arange(M)
    train_pos_eids, val_pos_eids, test_pos_eids = eids_split(eids, val_ratio, test_ratio)

    train_pos_u, train_pos_v = u[train_pos_eids], v[train_pos_eids]
    val_pos_u, val_pos_v = u[val_pos_eids], v[val_pos_eids]
    test_pos_u, test_pos_v = u[test_pos_eids], v[test_pos_eids]

    if neg == 'pred_etype_neg':
        # Edges not in pred_etype as negative edges
        adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(src_N, tgt_N))
        adj_neg = 1 - adj.todense()
        neg_u, neg_v = np.where(adj_neg != 0)
    elif neg == 'src_tgt_neg':
        # Edges not connecting src and tgt as negative edges
        
        # Collect all edges between the src and tgt
        src_tgt_indices = []
        for etype in g.canonical_etypes:
            if etype[0] == src_ntype and etype[2] == tgt_ntype:
                adj = g.adj(etype=etype)
                src_tgt_index = adj.coalesce().indices()        
                src_tgt_indices += [src_tgt_index]
        src_tgt_u, src_tgt_v = torch.cat(src_tgt_indices, dim=1)

        # Find all negative edges that are not in src_tgt_indices
        adj = sp.coo_matrix((np.ones(len(src_tgt_u)), (src_tgt_u.numpy(), src_tgt_v.numpy())), shape=(src_N, tgt_N))
        adj_neg = 1 - adj.todense()
        neg_u, neg_v = np.where(adj_neg != 0)
    else:
        raise ValueError('Unknown negative argument')
        
    neg_eids = np.random.choice(neg_u.shape[0], min(neg_u.shape[0], M), replace=False)
    train_neg_eids, val_neg_eids, test_neg_eids = eids_split(torch.from_numpy(neg_eids), val_ratio, test_ratio)

    # train_neg_u, train_neg_v = neg_u[train_neg_eids], neg_v[train_neg_eids]
    # val_neg_u, val_neg_v = neg_u[val_neg_eids], neg_v[val_neg_eids]
    # test_neg_u, test_neg_v = neg_u[test_neg_eids], neg_v[test_neg_eids]

    # Avoid losing dimension in single number slicing
    train_neg_u, train_neg_v = np.take(neg_u, train_neg_eids), np.take(neg_v, train_neg_eids)
    val_neg_u, val_neg_v = np.take(neg_u, val_neg_eids),np.take(neg_v, val_neg_eids)
    test_neg_u, test_neg_v = np.take(neg_u, test_neg_eids), np.take(neg_v, test_neg_eids)
    
    # Construct graphs
    pred_can_etype = (src_ntype, pred_etype, tgt_ntype)
    num_nodes_dict = get_num_nodes_dict(g)
    
    train_pos_g = dgl.heterograph({pred_can_etype: (train_pos_u, train_pos_v)}, num_nodes_dict)
    train_neg_g = dgl.heterograph({pred_can_etype: (train_neg_u, train_neg_v)}, num_nodes_dict)
    val_pos_g = dgl.heterograph({pred_can_etype: (val_pos_u, val_pos_v)}, num_nodes_dict)
    val_neg_g = dgl.heterograph({pred_can_etype: (val_neg_u, val_neg_v)}, num_nodes_dict)
    test_pos_g = dgl.heterograph({pred_can_etype: (test_pos_u, test_pos_v)}, num_nodes_dict)
    test_neg_g = dgl.heterograph({pred_can_etype: (test_neg_u, test_neg_v)}, num_nodes_dict)
    
    mp_g = remove_all_edges_of_etype(g, pred_etype) # Remove pred_etype edges but keep nodes
    return mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g


def load_dataset(dataset_dir, dataset_name, val_ratio, test_ratio):
    '''
    Parameters
    ----------
    dataset_dir : string
        dataset directory
    
    dataset_name : string
    
    val_ratio : float
    
    test_ratio : float

    Returns
    ----------
    g: dgl graph
        The original graph

    processed_g: tuple of seven dgl graphs
        The outputs of the function `process_data`, 
        which includes g for message passing, train, valid, and test
        
    pred_pair_to_edge_labels : dict
        key=((source node type, source node id), (target node type, target node id))
        value=dict, {canonical edge type: (source node ids, target node ids)}
        
    pred_pair_to_path_labels : dict 
        key=((source node type, source node id), (target node type, target node id))
        value=list of lists, each list contains (canonical edge type, source node ids, target node ids)
    '''
    graph_saving_path = f'{dataset_dir}/{dataset_name}'
    graph_list, _ = dgl.load_graphs(graph_saving_path)
    pred_pair_to_edge_labels = torch.load(f'{graph_saving_path}_pred_pair_to_edge_labels')
    pred_pair_to_path_labels = torch.load(f'{graph_saving_path}_pred_pair_to_path_labels')
    g = graph_list[0]
    if 'synthetic' in dataset_name:
        src_ntype, tgt_ntype = 'user', 'item'
    elif 'citation' in dataset_name:
        src_ntype, tgt_ntype = 'author', 'paper'

    pred_etype = 'likes'
    neg = 'src_tgt_neg'
    processed_g = process_data(g, val_ratio, test_ratio, src_ntype, tgt_ntype, pred_etype, neg)
    return g, processed_g, pred_pair_to_edge_labels, pred_pair_to_path_labels
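
A hypothetical usage sketch (assuming the aug_citation graph and label files have already been generated under datasets/):

from data_processing import load_dataset

g, processed_g, edge_labels, path_labels = load_dataset('datasets', 'aug_citation', 0.1, 0.2)
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = processed_g
print(mp_g)  # the message-passing graph, with all 'likes' edges removed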

eval_explanations.py

import torch
import numpy as np
import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from tqdm.auto import tqdm

from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel
from utils import set_config_args, get_comp_g_edge_labels, get_comp_g_path_labels
from utils import hetero_src_tgt_khop_in_subgraph, eval_edge_mask_auc, eval_edge_mask_topk_path_hit

parser = argparse.ArgumentParser(description='Explain link predictor')
'''
Dataset args
'''
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1) 
parser.add_argument('--test_ratio', type=float, default=0.2)
parser.add_argument('--max_num_samples', type=int, default=-1, 
                    help='maximum number of samples to explain, for fast testing. Use all if -1')

'''
GNN args
'''
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)
parser.add_argument('--saved_model_dir', type=str, default='saved_models')
parser.add_argument('--saved_model_name', type=str, default='')

'''
Link predictor args
'''
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument('--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
                   help='operation passed to dgl.EdgePredictor')

'''
Explanation args
'''
parser.add_argument('--num_hops', type=int, default=2, help='computation graph number of hops') 
parser.add_argument('--saved_explanation_dir', type=str, default='saved_explanations',
                    help='directory of saved explanations')
parser.add_argument('--eval_explainer_names', nargs='+', default=['pagelink'],
                    help='name of explainers to evaluate') 
parser.add_argument('--eval_path_hit', default=False, action='store_true', 
                    help='Whether to evaluate the top-k path hit rate') 
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()

if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'train_eval')

if 'citation' in args.dataset_name:
    args.src_ntype = 'author'
    args.tgt_ntype = 'paper'

elif 'synthetic' in args.dataset_name:
    args.src_ntype = 'user'
    args.tgt_ntype = 'item'    
    
if args.link_pred_op in ['cat']:
    pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1}
else:
    pred_kwargs = {}

g, processed_g, pred_pair_to_edge_labels, pred_pair_to_path_labels = load_dataset(args.dataset_dir,
                                                                                  args.dataset_name,
                                                                                  args.valid_ratio,
                                                                                  args.test_ratio)
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = processed_g
encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)

if not args.saved_model_name:
    args.saved_model_name = f'{args.dataset_name}_model'

state = torch.load(f'{args.saved_model_dir}/{args.saved_model_name}.pth', map_location='cpu')
model.load_state_dict(state)    

test_src_nids, test_tgt_nids = test_pos_g.edges()
comp_graphs = defaultdict(list)
comp_g_labels = defaultdict(list)
test_ids = range(test_src_nids.shape[0])
if args.max_num_samples > 0:
    test_ids = test_ids[:args.max_num_samples]

for i in tqdm(test_ids):
    # Get the k-hop subgraph
    src_nid, tgt_nid = test_src_nids[i], test_tgt_nids[i]
    comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids = hetero_src_tgt_khop_in_subgraph(args.src_ntype, 
                                                                                               src_nid,
                                                                                               args.tgt_ntype,
                                                                                               tgt_nid,
                                                                                               mp_g,
                                                                                               args.num_hops)

    with torch.no_grad():
        pred = model(comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids).sigmoid().item() > 0.5

    if pred:
        src_tgt = ((args.src_ntype, int(src_nid)), (args.tgt_ntype, int(tgt_nid)))
        comp_graphs[src_tgt] = [comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids]

        # Get labels with subgraph nids and eids 
        edge_labels = pred_pair_to_edge_labels[src_tgt]
        comp_g_edge_labels = get_comp_g_edge_labels(comp_g, edge_labels)

        path_labels = pred_pair_to_path_labels[src_tgt]
        comp_g_path_labels = get_comp_g_path_labels(comp_g, path_labels)

        comp_g_labels[src_tgt] = [comp_g_edge_labels, comp_g_path_labels]

explanation_masks = {}
for explainer in args.eval_explainer_names:
    saved_explanation_mask = f'{explainer}_{args.saved_model_name}_pred_edge_to_comp_g_edge_mask'
    saved_file = Path.cwd().joinpath(args.saved_explanation_dir, saved_explanation_mask)
    with open(saved_file, "rb") as f:
        explanation_masks[explainer] = pickle.load(f)

print('Dataset:', args.dataset_name)
for explainer in args.eval_explainer_names:
    print(explainer)
    print('-'*30)
    pred_edge_to_comp_g_edge_mask = explanation_masks[explainer]
    
    mask_auc_list = []
    for src_tgt in comp_graphs:
        comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids = comp_graphs[src_tgt]
        comp_g_edge_labels, comp_g_path_labels = comp_g_labels[src_tgt]
        comp_g_edge_mask_dict = pred_edge_to_comp_g_edge_mask[src_tgt]
        mask_auc = eval_edge_mask_auc(comp_g_edge_mask_dict, comp_g_edge_labels)
        mask_auc_list += [mask_auc]
      
    avg_auc = np.mean(mask_auc_list)
    
    # Print
    np.set_printoptions(precision=4, suppress=True)
    print(f'Average Mask-AUC: {avg_auc : .4f}')
    
    print('-'*30, '\n')

if args.eval_path_hit:
    topks = [3, 5, 10, 20, 50, 100, 200]
    for explainer in args.eval_explainer_names:
        print(explainer)
        print('-'*30)
        pred_edge_to_comp_g_edge_mask = explanation_masks[explainer]

        explainer_to_topk_path_hit = defaultdict(list)
        for src_tgt in comp_graphs:
            comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids = comp_graphs[src_tgt]
            comp_g_path_labels = comp_g_labels[src_tgt][1]
            comp_g_edge_mask_dict = pred_edge_to_comp_g_edge_mask[src_tgt]
            topk_to_path_hit = eval_edge_mask_topk_path_hit(comp_g_edge_mask_dict, comp_g_path_labels, topks)

            for topk in topk_to_path_hit:
                explainer_to_topk_path_hit[topk] += [topk_to_path_hit[topk]]        

        # Take average
        explainer_to_topk_path_hit_rate = defaultdict(list)
        for topk in explainer_to_topk_path_hit:
            metric = np.array(explainer_to_topk_path_hit[topk])
            explainer_to_topk_path_hit_rate[topk] = metric.mean(0)

        # Print
        np.set_printoptions(precision=4, suppress=True)
        for k, hr in explainer_to_topk_path_hit_rate.items():
            print(f'k: {k :3} | Path HR: {hr.item(): .4f}')

        print('-'*30, '\n')

explainer.py

import dgl
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
from collections import defaultdict
from utils import get_ntype_hetero_nids_to_homo_nids, get_homo_nids_to_ntype_hetero_nids, get_ntype_pairs_to_cannonical_etypes
from utils import hetero_src_tgt_khop_in_subgraph, get_neg_path_score_func, k_shortest_paths_with_max_length

def get_edge_mask_dict(ghetero):
    '''
    Create a dictionary mapping etypes to learnable edge masks 
            
    Parameters
    ----------
    ghetero : heterogeneous dgl graph.

    Return
    ----------
    edge_mask_dict : dictionary
        key=etype, value=torch.nn.Parameter with size number of etype edges
    '''
    device = ghetero.device
    edge_mask_dict = {}
    for etype in ghetero.canonical_etypes:
        num_edges = ghetero.num_edges(etype)
        num_nodes = ghetero.edge_type_subgraph([etype]).num_nodes()

        std = torch.nn.init.calculate_gain('relu') * np.sqrt(2.0 / (2 * num_nodes))
        edge_mask_dict[etype] = torch.nn.Parameter(torch.randn(num_edges, device=device) * std)
    return edge_mask_dict

def remove_edges_of_high_degree_nodes(ghomo, max_degree=10, always_preserve=[]):
    '''
    For all the nodes with degree higher than `max_degree`, 
    except nodes in `always_preserve`, remove their edges. 
    
    Parameters
    ----------
    ghomo : dgl homogeneous graph
    
    max_degree : int
    
    always_preserve : iterable
        These nodes won't be pruned.
    
    Returns
    -------
    low_degree_ghomo : dgl homogeneous graph
        Pruned graph with edges of high degree nodes removed

    '''
    d = ghomo.in_degrees()
    high_degree_mask = d > max_degree
    
    # preserve nodes
    high_degree_mask[always_preserve] = False    

    high_degree_nids = ghomo.nodes()[high_degree_mask]
    u, v = ghomo.edges()
    high_degree_edge_mask = torch.isin(u, high_degree_nids) | torch.isin(v, high_degree_nids)
    high_degree_u, high_degree_v = u[high_degree_edge_mask], v[high_degree_edge_mask]
    high_degree_eids = ghomo.edge_ids(high_degree_u, high_degree_v)
    low_degree_ghomo = dgl.remove_edges(ghomo, high_degree_eids)
    
    return low_degree_ghomo


def remove_edges_except_k_core_graph(ghomo, k, always_preserve=[]):
    '''
    Find the `k`-core of `ghomo`.
    Only isolate the low-degree nodes by removing their edges
    instead of removing the nodes, so node ids can be kept.
    
    Parameters
    ----------
    ghomo : dgl homogeneous graph
    
    k : int
    
    always_preserve : iterable
        These nodes won't be pruned.
    
    Returns
    -------
    k_core_ghomo : dgl homogeneous graph
        The k-core graph
    '''
    k_core_ghomo = ghomo
    degrees = k_core_ghomo.in_degrees()
    k_core_mask = (degrees > 0) & (degrees < k)
    k_core_mask[always_preserve] = False
    
    while k_core_mask.any():
        k_core_nids = k_core_ghomo.nodes()[k_core_mask]
        
        u, v = k_core_ghomo.edges()
        k_core_edge_mask = torch.isin(u, k_core_nids) | torch.isin(v, k_core_nids)
        k_core_u, k_core_v = u[k_core_edge_mask], v[k_core_edge_mask]
        k_core_eids = k_core_ghomo.edge_ids(k_core_u, k_core_v)

        k_core_ghomo = dgl.remove_edges(k_core_ghomo, k_core_eids)
        
        degrees = k_core_ghomo.in_degrees()
        k_core_mask = (degrees > 0) & (degrees < k)
        k_core_mask[always_preserve] = False

    return k_core_ghomo

def get_eids_on_paths(paths, ghomo):
    '''
    Collect all edge ids on the paths
    
    Note: The current version is a list version: an edge may be collected multiple times.
    A different version would be a set version, where an edge contributes only once
    even if it appears in multiple paths.
    
    Parameters
    ----------
    paths : list of lists
        Each list contains the node ids of one path
    
    ghomo : dgl homogeneous graph
    
    Returns
    -------
    eids : torch.LongTensor
        Ids of all edges on the paths
    '''
    row, col = ghomo.edges()
    eids = []
    for path in paths:
        for i in range(len(path)-1):
            eids += [((row == path[i]) & (col == path[i+1])).nonzero().item()]
            
    return torch.LongTensor(eids)

def comp_g_paths_to_paths(comp_g, comp_g_paths):
    paths = []
    g_nids = comp_g.ndata[dgl.NID]
    for comp_g_path in comp_g_paths:
        path = []
        for can_etype, u, v in comp_g_path:
            u_ntype, _, v_ntype = can_etype
            path += [(can_etype, g_nids[u_ntype][u].item(), g_nids[v_ntype][v].item())]
        paths += [path]
    return paths


class PaGELink(nn.Module):
    """Path-based GNN Explanation for Heterogeneous Link Prediction (PaGELink)
    
    Some methods are adapted from the DGL GNNExplainer implementation
    https://docs.dgl.ai/en/0.8.x/_modules/dgl/nn/pytorch/explain/gnnexplainer.html#GNNExplainer
    
    Parameters
    ----------
    model : nn.Module
        The GNN-based link prediction model to explain.

        * The required arguments of its forward function are source node id, target node id,
          graph, and feature ids. The feature ids are for selecting input node features.
        * It should also optionally take an eweight argument for edge weights
          and multiply the messages by the weights during message passing.
        * The output of its forward function is the logits in (-inf, inf) for the 
          predicted link.
    lr : float, optional
        The learning rate to use, default to 0.001.
    num_epochs : int, optional
        The number of epochs to train.
    alpha : float, optional
        A higher value will make edges on high-quality paths have higher weights.
    beta : float, optional
        A higher value will make edges off high-quality paths have lower weights.
    log : bool, optional
        If True, it will log the computation process, default to False.
    """
    def __init__(self,
                 model,
                 lr=0.001,
                 num_epochs=100,
                 alpha=1.0,
                 beta=1.0,
                 log=False):
        super(PaGELink, self).__init__()
        self.model = model
        self.src_ntype = model.src_ntype
        self.tgt_ntype = model.tgt_ntype
        
        self.lr = lr
        self.num_epochs = num_epochs
        self.alpha = alpha
        self.beta = beta
        self.log = log
        
        self.all_loss = defaultdict(list)

    def _init_masks(self, ghetero):
        """Initialize the learnable edge mask.

        Parameters
        ----------
        ghetero : heterogeneous dgl graph
            Input graph.

        Returns
        -------
        edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges
        """
        return get_edge_mask_dict(ghetero)
    

    def _prune_graph(self, ghetero, prune_max_degree=-1, k_core=2, always_preserve=[]):
        # Prune edges by (optionally) removing edges of high degree nodes and extracting k-core
        # The pruning is computed on the homogeneous graph, i.e., ignoring node/edge types
        ghomo = dgl.to_homogeneous(ghetero)
        device = ghetero.device
        ghomo.edata['eid_before_prune'] = torch.arange(ghomo.num_edges()).to(device)
        
        if prune_max_degree > 0:
            max_degree_pruned_ghomo = remove_edges_of_high_degree_nodes(ghomo, prune_max_degree, always_preserve)
            k_core_ghomo = remove_edges_except_k_core_graph(max_degree_pruned_ghomo, k_core, always_preserve)
            
            if k_core_ghomo.num_edges() <= 0: # no k-core found
                pruned_ghomo = max_degree_pruned_ghomo
            else:
                pruned_ghomo = k_core_ghomo
        else:
            k_core_ghomo = remove_edges_except_k_core_graph(ghomo, k_core, always_preserve)
            if k_core_ghomo.num_edges() <= 0: # no k-core found
                pruned_ghomo = ghomo
            else:
                pruned_ghomo = k_core_ghomo
        
        pruned_ghomo_eids = pruned_ghomo.edata['eid_before_prune']
        pruned_ghomo_eid_mask = torch.zeros(ghomo.num_edges()).bool()
        pruned_ghomo_eid_mask[pruned_ghomo_eids] = True

        # Apply the pruning result on the heterogeneous graph
        etypes_to_pruned_ghetero_eid_masks = {}
        pruned_ghetero = ghetero
        cum_num_edges = 0
        for etype in ghetero.canonical_etypes:
            num_edges = ghetero.num_edges(etype=etype)
            pruned_ghetero_eid_mask = pruned_ghomo_eid_mask[cum_num_edges:cum_num_edges+num_edges]
            etypes_to_pruned_ghetero_eid_masks[etype] = pruned_ghetero_eid_mask

            remove_ghetero_eids = (~ pruned_ghetero_eid_mask).nonzero().view(-1).to(device)
            pruned_ghetero = dgl.remove_edges(pruned_ghetero, eids=remove_ghetero_eids, etype=etype)

            cum_num_edges += num_edges
                
        return pruned_ghetero, etypes_to_pruned_ghetero_eid_masks
        
        
    def path_loss(self, src_nid, tgt_nid, g, eweights, num_paths=5):
        """Compute the path loss.

        Parameters
        ----------
        src_nid : int
            source node id

        tgt_nid : int
            target node id

        g : dgl graph

        eweights : Tensor
            Edge weights with shape equals the number of edges.
            
        num_paths : int
            Number of paths to compute path loss on

        Returns
        -------
        loss : Tensor
            The path loss
        """
        neg_path_score_func = get_neg_path_score_func(g, 'eweight', [src_nid, tgt_nid])
        paths = k_shortest_paths_with_max_length(g, 
                                                 src_nid, 
                                                 tgt_nid, 
                                                 weight=neg_path_score_func, 
                                                 k=num_paths)

        eids_on_path = get_eids_on_paths(paths, g)

        if eids_on_path.nelement() > 0:
            loss_on_path = - eweights[eids_on_path].mean()
        else:
            loss_on_path = 0

        eids_off_path_mask = ~torch.isin(torch.arange(eweights.shape[0]), eids_on_path)
        if eids_off_path_mask.any():
            loss_off_path = eweights[eids_off_path_mask].mean()
        else:
            loss_off_path = 0

        loss = self.alpha * loss_on_path + self.beta * loss_off_path 

        self.all_loss['loss_on_path'] += [float(loss_on_path)]
        self.all_loss['loss_off_path'] += [float(loss_off_path)]

        return loss   

    
    def get_edge_mask(self, 
                      src_nid, 
                      tgt_nid, 
                      ghetero, 
                      feat_nids, 
                      prune_max_degree=-1,
                      k_core=2, 
                      prune_graph=True,
                      with_path_loss=True):

        """Learning the edge mask dict.   
        
        Parameters
        ----------
        see the `explain` method.
        
        Returns
        -------
        edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges
        """

        self.model.eval()
        device = ghetero.device
        
        ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)    
        homo_src_nid = ntype_hetero_nids_to_homo_nids[(self.src_ntype, int(src_nid))]
        homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(self.tgt_ntype, int(tgt_nid))]

        # Get the initial prediction.
        with torch.no_grad():
            score = self.model(src_nid, tgt_nid, ghetero, feat_nids)
            pred = (score > 0).int().item()

        if prune_graph:
            # The pruned graph for mask learning  
            ml_ghetero, etypes_to_pruned_ghetero_eid_masks = self._prune_graph(ghetero, 
                                                                               prune_max_degree,
                                                                               k_core,
                                                                               [homo_src_nid, homo_tgt_nid])
        else:
            # The original graph for mask learning  
            ml_ghetero = ghetero
            
        ml_edge_mask_dict = self._init_masks(ml_ghetero)
        optimizer = torch.optim.Adam(ml_edge_mask_dict.values(), lr=self.lr)
        
        if self.log:
            pbar = tqdm(total=self.num_epochs)

        eweight_norm = 0
        EPS = 1e-3
        for e in range(self.num_epochs):    
            
            # Apply sigmoid to edge_mask to get eweight
            ml_eweight_dict = {etype: ml_edge_mask_dict[etype].sigmoid() for etype in ml_edge_mask_dict}
        
            score = self.model(src_nid, tgt_nid, ml_ghetero, feat_nids, ml_eweight_dict)
            pred_loss = (-1) ** pred * score.sigmoid().log()
            self.all_loss['pred_loss'] += [pred_loss.item()]

            ml_ghetero.edata['eweight'] = ml_eweight_dict
            ml_ghomo = dgl.to_homogeneous(ml_ghetero, edata=['eweight'])
            ml_ghomo_eweights = ml_ghomo.edata['eweight']
            
            # Check for early stop
            curr_eweight_norm = ml_ghomo_eweights.norm()
            if abs(eweight_norm - curr_eweight_norm) < EPS:
                break
            eweight_norm = curr_eweight_norm
            
            # Update with path loss
            if with_path_loss:
                path_loss = self.path_loss(homo_src_nid, homo_tgt_nid, ml_ghomo, ml_ghomo_eweights)
            else: 
                path_loss = 0
            
            loss = pred_loss + path_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            self.all_loss['total_loss'] += [loss.item()]

            if self.log:
                pbar.update(1)

        if self.log:
            pbar.close()

        edge_mask_dict_placeholder = self._init_masks(ghetero)
        edge_mask_dict = {}
        
        if prune_graph:
            # remove pruned edges
            for etype in ghetero.canonical_etypes:
                edge_mask = edge_mask_dict_placeholder[etype].data + float('-inf')    
                pruned_ghetero_eid_mask = etypes_to_pruned_ghetero_eid_masks[etype]
                edge_mask[pruned_ghetero_eid_mask] = ml_edge_mask_dict[etype]
                edge_mask_dict[etype] = edge_mask
                
        else:
            edge_mask_dict = ml_edge_mask_dict
    
        edge_mask_dict = {k : v.detach() for k, v in edge_mask_dict.items()}
        return edge_mask_dict    

    def get_paths(self,
                  src_nid, 
                  tgt_nid, 
                  ghetero,
                  edge_mask_dict,
                  num_paths=1, 
                  max_path_length=3):

        """A postprocessing step that turns the `edge_mask_dict` into actual paths.
        
        Parameters
        ----------
        edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges

        Others: see the `explain` method.
        
        Returns
        -------
        paths: list of lists
            each list contains (canonical edge type, source node ids, target node ids)
        """
        ntype_pairs_to_cannonical_etypes = get_ntype_pairs_to_cannonical_etypes(ghetero)
        eweight_dict = {etype: edge_mask_dict[etype].sigmoid() for etype in edge_mask_dict}
        ghetero.edata['eweight'] = eweight_dict

        # convert ghetero to ghomo and find paths
        ghomo = dgl.to_homogeneous(ghetero, edata=['eweight'])
        ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)    
        homo_src_nid = ntype_hetero_nids_to_homo_nids[(self.src_ntype, int(src_nid))]
        homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(self.tgt_ntype, int(tgt_nid))]

        neg_path_score_func = get_neg_path_score_func(ghomo, 'eweight', [src_nid.item(), tgt_nid.item()])
        homo_paths = k_shortest_paths_with_max_length(ghomo, 
                                                       homo_src_nid, 
                                                       homo_tgt_nid,
                                                       weight=neg_path_score_func,
                                                       k=num_paths,
                                                       max_length=max_path_length)

        paths = []
        homo_nids_to_ntype_hetero_nids = get_homo_nids_to_ntype_hetero_nids(ghetero)
    
        if len(homo_paths) > 0:
            for homo_path in homo_paths:
                hetero_path = []
                for i in range(1, len(homo_path)):
                    homo_u, homo_v = homo_path[i-1], homo_path[i]
                    hetero_u_ntype, hetero_u_nid = homo_nids_to_ntype_hetero_nids[homo_u] 
                    hetero_v_ntype, hetero_v_nid = homo_nids_to_ntype_hetero_nids[homo_v] 
                    can_etype = ntype_pairs_to_cannonical_etypes[(hetero_u_ntype, hetero_v_ntype)]    
                    hetero_path += [(can_etype, hetero_u_nid, hetero_v_nid)]
                paths += [hetero_path]

        else:
            # A rare case, no paths found, take the top edges
            cat_edge_mask = torch.cat([v for v in edge_mask_dict.values()])
            M = len(cat_edge_mask)
            k = min(num_paths * max_path_length, M)
            threshold = cat_edge_mask.topk(k)[0][-1].item()
            path = []
            for etype in edge_mask_dict:
                u, v = ghetero.edges(etype=etype)  
                topk_edge_mask = edge_mask_dict[etype] >= threshold
                path += list(zip([etype] * topk_edge_mask.sum().item(), u[topk_edge_mask].tolist(), v[topk_edge_mask].tolist()))                
            paths = [path]
        return paths
    
    def explain(self,  
                src_nid, 
                tgt_nid, 
                ghetero,
                num_hops=2,
                prune_max_degree=-1,
                k_core=2, 
                num_paths=1, 
                max_path_length=3,
                prune_graph=True,
                with_path_loss=True,
                return_mask=False):
        
        """Return a path explanation of a predicted link
        
        Parameters
        ----------
        src_nid : int
            source node id

        tgt_nid : int
            target node id

        ghetero : dgl graph

        num_hops : int
            Number of hops for extracting the computation graph, i.e., the number of GNN layers
            
        prune_max_degree : int
            If positive, prune the edges of graph nodes with degree larger than `prune_max_degree`.
            If -1, do nothing.
            
        k_core : int 
            k for the k-core graph extraction
            
        num_paths : int
            Number of paths for the postprocessing path extraction
            
        max_path_length : int
            Maximum length of paths for the postprocessing path extraction
        
        prune_graph : bool
            If True, apply the max_degree and/or k-core pruning. For ablation. Default True.
            
        with_path_loss : bool
            If True, include the path loss. For ablation. Default True.
            
        return_mask : bool
            If True, return the edge mask in addition to the paths. For AUC evaluation. Default False.
        
        Returns
        -------
        paths: list of lists
            each list contains (canonical edge type, source node ids, target node ids)

        (optional) edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges
        """
        # Extract the computation graph (k-hop subgraph)
        (comp_g_src_nid, 
         comp_g_tgt_nid, 
         comp_g, 
         comp_g_feat_nids) = hetero_src_tgt_khop_in_subgraph(self.src_ntype, 
                                                             src_nid, 
                                                             self.tgt_ntype, 
                                                             tgt_nid, 
                                                             ghetero, 
                                                             num_hops)
        # Learn the edge mask on the computation graph
        comp_g_edge_mask_dict = self.get_edge_mask(comp_g_src_nid, 
                                                   comp_g_tgt_nid, 
                                                   comp_g, 
                                                   comp_g_feat_nids,
                                                   prune_max_degree,
                                                   k_core,
                                                   prune_graph,
                                                   with_path_loss)

        # Extract paths 
        comp_g_paths = self.get_paths(comp_g_src_nid,
                                      comp_g_tgt_nid, 
                                      comp_g, 
                                      comp_g_edge_mask_dict, 
                                      num_paths, 
                                      max_path_length)    
        
        
        # Convert the node ids in the computation graph to ids in the original graph
        paths = comp_g_paths_to_paths(comp_g, comp_g_paths)
        
        if return_mask:
            # return masks for easier evaluation
            return paths, comp_g_edge_mask_dict
        else:
            return paths 
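
Reading get_edge_mask and path_loss together, the per-epoch objective works out to (my notation, reconstructed from the code above rather than quoted from the paper):

\mathcal{L} = (-1)^{\hat{y}} \log \sigma(s) - \alpha \cdot \mathrm{mean}_{e \in E_{\mathrm{on}}}(w_e) + \beta \cdot \mathrm{mean}_{e \in E_{\mathrm{off}}}(w_e)

where \hat{y} is the initial binary prediction, s is the model score under the current mask, w_e = \sigma(m_e) are the sigmoid-ed edge masks, and E_on / E_off are the edges on / off the current top-k shortest paths. Minimizing it preserves the original prediction while pushing mask weights up on path edges and down everywhere else.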

model.py

import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import HeteroEmbedding, EdgePredictor

'''
HeteroRGCN model adapted from the DGL official tutorial
https://docs.dgl.ai/en/0.6.x/tutorials/basics/5_hetero.html
https://docs.dgl.ai/en/0.8.x/tutorials/models/1_gnn/4_rgcn.html
'''


class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_0 for transforming the node's own feature
        self.weight0 = nn.Linear(in_size, out_size)
        
        # W_r for each relation
        self.weight = nn.ModuleDict({
                name : nn.Linear(in_size, out_size) for name in etypes
            })

    def forward(self, g, feat_dict, eweight_dict=None):
        # The input is a dictionary of node features for each type
        funcs = {}
        if eweight_dict is not None:
            # Store the edge weights (the caller has already applied sigmoid)
            g.edata['_edge_weight'] = eweight_dict
                
        for srctype, etype, dsttype in g.canonical_etypes:
            # Compute h_0 = W_0 * h
            h0 = self.weight0(feat_dict[srctype])
            g.nodes[srctype].data['h0'] = h0
            
            # Compute h_r = W_r * h
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in graph for message passing
            g.nodes[srctype].data['Wh_%s' % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func).
            # Note that the results are saved to the same destination feature 'h', which
            # hints the type wise reducer for aggregation.
            if eweight_dict is not None:
                msg_fn = fn.u_mul_e('Wh_%s' % etype, '_edge_weight', 'm')
            else:
                msg_fn = fn.copy_u('Wh_%s' % etype, 'm')
                
            funcs[(srctype, etype, dsttype)] = (msg_fn, fn.mean('m', 'h'))

        def apply_func(nodes):
            h = nodes.data['h'] + nodes.data['h0']
            return {'h': h}
            
        # Trigger message passing of multiple types.
        # The first argument is the message passing functions for each relation.
        # The second one is the type wise reducer, could be "sum", "max",
        # "min", "mean", "stack"
        g.multi_update_all(funcs, 'sum', apply_func)
        # g.multi_update_all(funcs, 'sum')

        # return the updated node feature dictionary
        return {ntype : g.nodes[ntype].data['h'] for ntype in g.ntypes}


class HeteroRGCN(nn.Module):
    def __init__(self, g, emb_dim, hidden_size, out_size):
        super(HeteroRGCN, self).__init__()
        self.emb = HeteroEmbedding({ntype : g.num_nodes(ntype) for ntype in g.ntypes}, emb_dim)
        self.layer1 = HeteroRGCNLayer(emb_dim, hidden_size, g.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, g.etypes)

    def forward(self, g, feat_nids=None, eweight_dict=None):
        if feat_nids is None:
            feat_dict = self.emb.weight
        else:
            feat_dict = self.emb(feat_nids)

        h_dict = self.layer1(g, feat_dict, eweight_dict)
        h_dict = {k : F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(g, h_dict, eweight_dict)
        return h_dict


class HeteroLinkPredictionModel(nn.Module):
    def __init__(self, encoder, src_ntype, tgt_ntype, link_pred_op='dot', **kwargs):
        super().__init__()
        self.encoder = encoder
        self.predictor = EdgePredictor(op=link_pred_op, **kwargs)
        self.src_ntype = src_ntype
        self.tgt_ntype = tgt_ntype

    def encode(self, g, feat_nids=None, eweight_dict=None):
        h = self.encoder(g, feat_nids, eweight_dict)
        return h

    def forward(self, src_nids, tgt_nids, g, feat_nids=None, eweight_dict=None):
        h = self.encode(g, feat_nids, eweight_dict)
        src_h = h[self.src_ntype][src_nids]
        tgt_h = h[self.tgt_ntype][tgt_nids]
        score = self.predictor(src_h, tgt_h).view(-1)
        return score
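
For reference, a tiny standalone sketch of DGL's EdgePredictor with the default 'dot' op (the trailing singleton dimension is why forward above ends with .view(-1)):

import torch
from dgl.nn import EdgePredictor

predictor = EdgePredictor(op='dot')
src_h = torch.randn(4, 128)       # embeddings of 4 source nodes
tgt_h = torch.randn(4, 128)       # embeddings of 4 target nodes
scores = predictor(src_h, tgt_h)  # rowwise dot products, shape (4, 1)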

pagelink.py

import os
import torch
import argparse
import pickle
from tqdm.auto import tqdm
from pathlib import Path

from utils import set_seed, print_args, set_config_args
from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel
from explainer import PaGELink


parser = argparse.ArgumentParser(description='Explain link predictor')
parser.add_argument('--device_id', type=int, default=-1)

'''
Dataset args
'''
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1) 
parser.add_argument('--test_ratio', type=float, default=0.2)
parser.add_argument('--max_num_samples', type=int, default=-1, 
                    help='maximum number of samples to explain, for fast testing. Use all if -1')

'''
GNN args
'''
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)
parser.add_argument('--saved_model_dir', type=str, default='saved_models')
parser.add_argument('--saved_model_name', type=str, default='')

'''
Link predictor args
'''
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument('--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
                   help='operation passed to dgl.EdgePredictor')

'''
Explanation args
'''
parser.add_argument('--lr', type=float, default=0.01, help='explainer learning_rate') 
parser.add_argument('--alpha', type=float, default=1.0, help='explainer on-path edge regularizer weight') 
parser.add_argument('--beta', type=float, default=1.0, help='explainer off-path edge regularizer weight') 
parser.add_argument('--num_hops', type=int, default=2, help='computation graph number of hops') 
parser.add_argument('--num_epochs', type=int, default=20, help='How many epochs to learn the mask')
parser.add_argument('--num_paths', type=int, default=40, help='How many paths to generate')
parser.add_argument('--max_path_length', type=int, default=5, help='max length of generated paths')
parser.add_argument('--k_core', type=int, default=2, help='k for the k-core graph') 
parser.add_argument('--prune_max_degree', type=int, default=200,
                    help='prune the graph such that all nodes have degree smaller than max_degree. No pruning if -1') 
parser.add_argument('--save_explanation', default=False, action='store_true', 
                    help='Whether to save the explanation')
parser.add_argument('--saved_explanation_dir', type=str, default='saved_explanations',
                    help='directory of saved explanations')
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()

if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'pagelink')

if 'citation' in args.dataset_name:
    args.src_ntype = 'author'
    args.tgt_ntype = 'paper'

elif 'synthetic' in args.dataset_name:
    args.src_ntype = 'user'
    args.tgt_ntype = 'item'    

if torch.cuda.is_available() and args.device_id >= 0:
    device = torch.device('cuda', index=args.device_id)
else:
    device = torch.device('cpu')

if args.link_pred_op in ['cat']:
    pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1}
else:
    pred_kwargs = {}
    
if not args.saved_model_name:
    args.saved_model_name = f'{args.dataset_name}_model'
    
print_args(args)
set_seed(0)

processed_g = load_dataset(args.dataset_dir, args.dataset_name, args.valid_ratio, args.test_ratio)[1]
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = [g.to(device) for g in processed_g]

encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)
state = torch.load(f'{args.saved_model_dir}/{args.saved_model_name}.pth', map_location='cpu')
model.load_state_dict(state)  

pagelink = PaGELink(model, 
                    lr=args.lr,
                    alpha=args.alpha, 
                    beta=args.beta, 
                    num_epochs=args.num_epochs,
                    log=True).to(device)


test_src_nids, test_tgt_nids = test_pos_g.edges()
test_ids = range(test_src_nids.shape[0])
if args.max_num_samples > 0:
    test_ids = test_ids[:args.max_num_samples]

pred_edge_to_comp_g_edge_mask = {}
pred_edge_to_paths = {}
for i in tqdm(test_ids):
    src_nid, tgt_nid = test_src_nids[i].unsqueeze(0), test_tgt_nids[i].unsqueeze(0)
    
    with torch.no_grad():
        pred = model(src_nid, tgt_nid, mp_g).sigmoid().item() > 0.5

    if pred:
        src_tgt = ((args.src_ntype, int(src_nid)), (args.tgt_ntype, int(tgt_nid)))
        paths, comp_g_edge_mask_dict = pagelink.explain(src_nid, 
                                                        tgt_nid, 
                                                        mp_g,
                                                        args.num_hops,
                                                        args.prune_max_degree,
                                                        args.k_core, 
                                                        args.num_paths, 
                                                        args.max_path_length,
                                                        return_mask=True)
        
        pred_edge_to_comp_g_edge_mask[src_tgt] = comp_g_edge_mask_dict 
        pred_edge_to_paths[src_tgt] = paths

if args.save_explanation:
    if not os.path.exists(args.saved_explanation_dir):
        os.makedirs(args.saved_explanation_dir)
        
    saved_edge_explanation_file = f'pagelink_{args.saved_model_name}_pred_edge_to_comp_g_edge_mask'
    saved_path_explanation_file = f'pagelink_{args.saved_model_name}_pred_edge_to_paths'
    pred_edge_to_comp_g_edge_mask = {edge: {k: v.cpu() for k, v in mask.items()} for edge, mask in pred_edge_to_comp_g_edge_mask.items()}

    saved_edge_explanation_path = Path.cwd().joinpath(args.saved_explanation_dir, saved_edge_explanation_file)
    with open(saved_edge_explanation_path, "wb") as f:
        pickle.dump(pred_edge_to_comp_g_edge_mask, f)

    saved_path_explanation_path = Path.cwd().joinpath(args.saved_explanation_dir, saved_path_explanation_file)
    with open(saved_path_explanation_path, "wb") as f:
        pickle.dump(pred_edge_to_paths, f)

train_linkpred.py

import os
import torch
import torch.nn.functional as F
import copy
import argparse
from sklearn.metrics import roc_auc_score
from pathlib import Path
from utils import set_seed, negative_sampling, print_args, set_config_args
from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel

parser = argparse.ArgumentParser(description='Train a GNN-based link prediction model')
parser.add_argument('--device_id', type=int, default=-1)

'''
Dataset args
'''
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1) 
parser.add_argument('--test_ratio', type=float, default=0.2)

'''
GNN args
'''
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)

'''
Link predictor args
'''
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument('--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
                   help='operation passed to dgl.EdgePredictor')
parser.add_argument('--lr', type=float, default=0.01, help='link predictor learning_rate') 
parser.add_argument('--num_epochs', type=int, default=200, help='How many epochs to train')
parser.add_argument('--eval_interval', type=int, default=1, help="Evaluate once per how many epochs")
parser.add_argument('--save_model', default=False, action='store_true', help='Whether to save the model')
parser.add_argument('--saved_model_dir', type=str, default='saved_models', help='Where to save the model')
parser.add_argument('--sample_neg_edges', default=False, action='store_true', 
                    help='If False, use fixed negative edges. If True, sample negative edges in each epoch')
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()

if 'synthetic' in args.dataset_name:
    args.src_ntype = 'user'
    args.tgt_ntype = 'item'

elif 'citation' in args.dataset_name:
    args.src_ntype = 'author'
    args.tgt_ntype = 'paper'
    
if torch.cuda.is_available() and args.device_id >= 0:
    device = torch.device('cuda', index=args.device_id)
else:
    device = torch.device('cpu')

if args.link_pred_op in ['cat']:
    pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1}
else:
    pred_kwargs = {}

if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'train_eval')
    
print_args(args)

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    device = scores.device
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).to(device)
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

def run():
    set_seed(0)
    best_val_auc = 0
    pred_etype= args.pred_etype
    train_pos_src_nids, train_pos_tgt_nids = train_pos_g.edges(etype=pred_etype)            
    val_pos_src_nids, val_pos_tgt_nids = val_pos_g.edges(etype=pred_etype)            
    val_neg_src_nids, val_neg_tgt_nids = val_neg_g.edges(etype=pred_etype)            
    test_pos_src_nids, test_pos_tgt_nids = test_pos_g.edges(etype=pred_etype)            
    test_neg_src_nids, test_neg_tgt_nids = test_neg_g.edges(etype=pred_etype)            

    train_neg_src_nids, train_neg_tgt_nids = train_neg_g.edges(etype=pred_etype) 

    for epoch in range(1, args.num_epochs+1):
        train_pos_score = model(train_pos_src_nids, train_pos_tgt_nids, mp_g)   
        if args.sample_neg_edges:
            train_neg_src_nids, train_neg_tgt_nids = negative_sampling(train_pos_g, pred_etype) 
        train_neg_score = model(train_neg_src_nids, train_neg_tgt_nids, mp_g)
        loss = compute_loss(train_pos_score, train_neg_score)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                train_auc = compute_auc(train_pos_score, train_neg_score)
                val_pos_score = model(val_pos_src_nids, val_pos_tgt_nids, mp_g)
                val_neg_score = model(val_neg_src_nids, val_neg_tgt_nids, mp_g)
                val_auc = compute_auc(val_pos_score, val_neg_score)
                print('In epoch {}, loss: {:.4f}, train AUC: {:.4f}, val AUC: {:.4f}'.format(epoch, loss, train_auc, val_auc))
                if val_auc > best_val_auc:
                    best_epoch = epoch
                    best_val_auc = val_auc
                    state = copy.deepcopy(model.state_dict())

    with torch.no_grad():
        model.eval()
        model.load_state_dict(state)
        test_pos_score = model(test_pos_src_nids, test_pos_tgt_nids, mp_g)
        test_neg_score = model(test_neg_src_nids, test_neg_tgt_nids, mp_g)
        test_auc = compute_auc(test_pos_score, test_neg_score)
        print('Best epoch {}, val AUC: {:.4f}, test AUC: {:.4f}'.format(best_epoch, best_val_auc, test_auc))

processed_g = load_dataset(args.dataset_dir, args.dataset_name, args.valid_ratio, args.test_ratio)[1]
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = [g.to(device) for g in processed_g]

encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

run()

if args.save_model:
    output_dir = Path.cwd().joinpath(args.saved_model_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    torch.save(model.state_dict(), output_dir.joinpath(f"{args.dataset_name}_model.pth"))
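
As a quick sanity check of the two metric helpers above (a hypothetical snippet that assumes compute_loss and compute_auc are in scope):

import torch

pos = torch.tensor([2.0, 1.5])    # logits for two positive pairs
neg = torch.tensor([-1.0, 1.8])   # logits for two negative pairs
print(compute_loss(pos, neg))     # BCE-with-logits against labels [1, 1, 0, 0]
print(compute_auc(pos, neg))      # 0.75: one negative logit outranks a positive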

utils.py

import dgl
import torch
import random
import textwrap
import yaml
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
from dgl.subgraph import khop_in_subgraph
from itertools import count
from heapq import heappop, heappush
from sklearn.metrics import roc_auc_score

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def print_args(args):
    for k, v in vars(args).items():
        print(f'{k:25} {v}')
        
def set_config_args(args, config_path, dataset_name, model_name=''):
    with open(config_path, "r") as conf:
        config = yaml.load(conf, Loader=yaml.FullLoader)[dataset_name]
        if model_name:
            config = config[model_name]

    for key, value in config.items():
        setattr(args, key, value)
    return args
    
'''
Model training utils
'''
def idx_split(idx, ratio, seed=0):
    """
    Randomly split `idx` into idx1 and idx2, where idx1 : idx2 = `ratio` : 1 - `ratio`
    
    Parameters
    ----------
    idx : tensor
        
    ratio: float
 
    Returns
    ----------
        Two index tensors after the split
    """
    set_seed(seed)
    n = len(idx)
    cut = int(n * ratio)
    idx_idx_shuffle = torch.randperm(n)

    idx1_idx, idx2_idx = idx_idx_shuffle[:cut], idx_idx_shuffle[cut:]
    idx1, idx2 = idx[idx1_idx], idx[idx2_idx]
    assert((torch.cat([idx1, idx2]).sort()[0] == idx.sort()[0]).all())
    return idx1, idx2


def eids_split(eids, val_ratio, test_ratio, seed=0):
    """
    Split `eids` into three parts: train, valid, and test,
    where train : valid : test = (1 - `val_ratio` - `test_ratio`) : `val_ratio` : `test_ratio`
    
    Parameters
    ----------
    eids : tensor
        edge ids
        
    val_ratio : float
    
    test_ratio : float

    seed : int

    Returns
    ----------
        Three edge id tensors after the split
    """
    train_ratio = (1 - val_ratio - test_ratio)
    train_eids, pred_eids = idx_split(eids, train_ratio, seed)
    val_eids, test_eids = idx_split(pred_eids, val_ratio / (1 - train_ratio), seed)
    return train_eids, val_eids, test_eids

def negative_sampling(graph, pred_etype=None, num_neg_samples=None):
    '''
    Adapted from PyG negative_sampling function
    https://pytorch-geometric.readthedocs.io/en/1.7.2/_modules/torch_geometric/utils/
    negative_sampling.html#negative_sampling

    Parameters
    ----------
    graph : dgl graph
    
    pred_etype : string
        The edge type for prediction

    num_neg_samples : int
    
    Returns
    ----------
        Two tensors of negative nids, for the src and tgt nodes of the `pred_etype`
    '''
    # src_N: total number of src nodes
    # N (tgt_N): total number of tgt nodes
    # M: total number of possible edges, the product src_N * tgt_N
    # pos_M: number of positive samples (observed edges)
    # neg_M: number of negative samples
    pos_src_nids, pos_tgt_nids = graph.edges(etype=pred_etype)
    if pred_etype is None:
        N = graph.num_nodes()
        M = N * N
    else:
        src_ntype, _, tgt_ntype = graph.to_canonical_etype(pred_etype) 
        src_N, N = graph.num_nodes(src_ntype), graph.num_nodes(tgt_ntype)
        M = src_N * N

    pos_M = pos_src_nids.shape[0]
    neg_M = num_neg_samples or pos_M
    neg_M = min(neg_M, M - pos_M) # in case M - pos_M < neg_M

    # Fraction of edges to oversample, so we only need to sample once in most cases
    alpha = abs(1 / (1 - 1.1 * (pos_M / M)))
    size = min(M, int(alpha * neg_M))
    perm = torch.tensor(random.sample(range(M), size))
    
    idx = pos_src_nids * N + pos_tgt_nids
    mask = torch.isin(perm, idx.to('cpu')).to(torch.bool)
    perm = perm[~mask][:neg_M].to(pos_src_nids.device)

    neg_src_nids = torch.div(perm, N, rounding_mode='floor')
    neg_tgt_nids = perm % N

    return neg_src_nids, neg_tgt_nids
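# Illustrative usage (a sketch, not part of the original file): sample negative
# author-paper pairs on a tiny 'likes' graph with 3 authors and 4 papers.
#
#   g = dgl.heterograph({('author', 'likes', 'paper'):
#                        (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]))},
#                       num_nodes_dict={'author': 3, 'paper': 4})
#   neg_src, neg_tgt = negative_sampling(g, pred_etype='likes')
#   # returns (src, tgt) id pairs that are not observed 'likes' edges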

'''
DGL graph manipulation utils
'''
def get_homo_nids_to_hetero_nids(ghetero):
    '''
    Create a dictionary mapping the node ids of the homogeneous version of the input graph
    to the node ids of the input heterogeneous graph.
    
    Parameters
    ----------
    ghetero : heterogeneous dgl graph
        
    Returns
    ----------
    homo_nids_to_hetero_nids : dict
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    homo_nids = range(ghomo.num_nodes())
    hetero_nids = ghomo.ndata[dgl.NID].tolist()
    homo_nids_to_hetero_nids = dict(zip(homo_nids, hetero_nids))
    return homo_nids_to_hetero_nids

def get_homo_nids_to_ntype_hetero_nids(ghetero):
    '''
    Create a dictionary mapping the node ids of the homogeneous version of the input graph
    to tuples as (node type, node id) of the input heterogeneous graph.
    
    Parameters
    ----------
    ghetero : heterogeneous dgl graph
        
    Returns
    ----------
    homo_nids_to_ntype_hetero_nids : dict
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    homo_nids = range(ghomo.num_nodes())
    ntypes = ghetero.ntypes
    # This relies on ntype ids being assigned in the order of ghetero.ntypes
    ntypes = [ntypes[i] for i in ghomo.ndata[dgl.NTYPE]] 
    hetero_nids = ghomo.ndata[dgl.NID].tolist()
    ntypes_hetero_nids = list(zip(ntypes, hetero_nids))
    homo_nids_to_ntype_hetero_nids = dict(zip(homo_nids, ntypes_hetero_nids))
    return homo_nids_to_ntype_hetero_nids

def get_ntype_hetero_nids_to_homo_nids(ghetero):
    '''
    Create a dictionary mapping tuples as (node type, node id) of the input heterogeneous graph
    to the node ids of the homogeneous version of the input graph.
    
    Parameters
    ----------
    ghetero : heterogeneous dgl graph
        
    Returns
    ----------
    ntype_hetero_nids_to_homo_nids : dict
    '''
    tmp = get_homo_nids_to_ntype_hetero_nids(ghetero)
    ntype_hetero_nids_to_homo_nids = {v: k for k, v in tmp.items()}
    return ntype_hetero_nids_to_homo_nids
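# Illustrative example (a sketch, not part of the original file): on a toy
# graph with 2 users and 2 items, dgl.to_homogeneous renumbers all nodes
# 0..3 (grouped by node type), and these helpers translate between id spaces.
#
#   g = dgl.heterograph({('user', 'clicks', 'item'):
#                        (torch.tensor([0, 1]), torch.tensor([0, 1]))})
#   get_homo_nids_to_ntype_hetero_nids(g)
#   # e.g. {0: ('item', 0), 1: ('item', 1), 2: ('user', 0), 3: ('user', 1)}
#   get_ntype_hetero_nids_to_homo_nids(g)  # the inverse mapping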

def get_ntype_pairs_to_cannonical_etypes(ghetero, pred_etype='likes'):
    '''
    Create a dictionary mapping tuples as (source node type, target node type) to 
    canonical edge types. Edges with type `pred_etype` will be excluded.
    A helper function for path finding.
    Only works if there is only one edge type between any pair of node types.
    
    Parameters
    ----------
    ghetero : heterogeneous dgl graph
      
    pred_etype : string
        The edge type for prediction

    Returns
    ----------
    ntype_pairs_to_cannonical_etypes : dict
    '''
    ntype_pairs_to_cannonical_etypes = {}
    for src_ntype, etype, tgt_ntype in ghetero.canonical_etypes:
        if etype != pred_etype:
            ntype_pairs_to_cannonical_etypes[(src_ntype, tgt_ntype)] = (src_ntype, etype, tgt_ntype)
    return ntype_pairs_to_cannonical_etypes

def get_num_nodes_dict(ghetero):
    '''
    Create a dictionary containing number of nodes of all ntypes in a heterogeneous graph
    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    Returns 
    ----------
    num_nodes_dict : dict
        key=node type, value=number of nodes
    '''
    num_nodes_dict = {}
    for ntype in ghetero.ntypes:
        num_nodes_dict[ntype] = ghetero.num_nodes(ntype)    
    return num_nodes_dict

def remove_all_edges_of_etype(ghetero, etype):
    '''
    Remove all edges with type `etype` from `ghetero`. If `etype` is not in `ghetero`, do nothing.
    
    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    etype : string or triple of strings
        Edge type in simple form (string) or canonical form (triple of strings)
    
    Returns 
    ----------
    removed_ghetero : heterogeneous dgl graph
        
    '''
    etype = ghetero.to_canonical_etype(etype)
    if etype in ghetero.canonical_etypes:
        eids = ghetero.edges('eid', etype=etype)
        removed_ghetero = dgl.remove_edges(ghetero, eids, etype=etype)
    else:
        removed_ghetero = ghetero
    return removed_ghetero

def hetero_src_tgt_khop_in_subgraph(src_ntype, src_nid, tgt_ntype, tgt_nid, ghetero, k):
    '''
    Find the `k`-hop subgraph around the src node and tgt node in `ghetero`
    The output will be the union of two subgraphs.
    See the dgl `khop_in_subgraph` function as a reference
    https://docs.dgl.ai/en/0.9.x/generated/dgl.khop_in_subgraph.html
    
    Parameters
    ----------
    src_ntype: string
        source node type
    
    src_nid : int
        source node id

    tgt_ntype: string
        target node type

    tgt_nid : int
        target node id

    ghetero : heterogeneous dgl graph

    k: int
        Number of hops

    Return
    ----------
    sghetero_src_nid: int
        id of the source node in the subgraph

    sghetero_tgt_nid: int
        id of the target node in the subgraph

    sghetero : heterogeneous dgl graph
        Union of two k-hop subgraphs

    sghetero_feat_nid: Tensor
        The original `ghetero` node ids of subgraph nodes, for feature identification
    
    '''
    # Extract k-hop subgraph centered at the (src, tgt) pair
    src_nid = src_nid.item() if torch.is_tensor(src_nid) else src_nid
    tgt_nid = tgt_nid.item() if torch.is_tensor(tgt_nid) else tgt_nid
    
    if src_ntype == tgt_ntype:
        pred_dict = {src_ntype: torch.tensor([src_nid, tgt_nid])}
        sghetero, inv_dict = khop_in_subgraph(ghetero, pred_dict, k)
        sghetero_src_nid = inv_dict[src_ntype][0]
        sghetero_tgt_nid = inv_dict[tgt_ntype][1]
    else:
        pred_dict = {src_ntype: src_nid, tgt_ntype: tgt_nid}
        sghetero, inv_dict = khop_in_subgraph(ghetero, pred_dict, k)
        sghetero_src_nid = inv_dict[src_ntype]
        sghetero_tgt_nid = inv_dict[tgt_ntype]

    sghetero_feat_nid = sghetero.ndata[dgl.NID]
    
    return sghetero_src_nid, sghetero_tgt_nid, sghetero, sghetero_feat_nid
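# Illustrative usage (a sketch, not part of the original file): take the union
# of the 2-hop neighborhoods around author 5 and paper 7 in ghetero.
#
#   sg_src, sg_tgt, sg, feat_nid = hetero_src_tgt_khop_in_subgraph(
#       'author', 5, 'paper', 7, ghetero, k=2)
#   # sg_src and sg_tgt are the ids of author 5 and paper 7 inside sg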


'''
Path finding utils
'''
def get_neg_path_score_func(g, weight, exclude_node=[]):
    '''
    Compute the negative path score for the shortest path algorithm.
    
    Parameters
    ----------
    g : dgl graph

    weight: string
       The edge weights stored in g.edata

    exclude_node : iterable
        The log in-degree of these nodes is set to 0 (no degree penalty), so they are more likely to be included in paths.

    Returns
    ----------
    neg_path_score_func: callable function
       Takes in two node ids and returns the negative path score of the edge between them.
    '''
    log_eweights = g.edata[weight].log().tolist()
    log_in_degrees = g.in_degrees().log()
    log_in_degrees[exclude_node] = 0
    log_in_degrees = log_in_degrees.tolist()
    u, v = g.edges()
    neg_path_score_map = {edge : log_in_degrees[edge[1]] - log_eweights[i] for i, edge in enumerate(zip(u.tolist(), v.tolist()))}

    def neg_path_score_func(u, v):
        return neg_path_score_map[(u, v)]
    return neg_path_score_func
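# Note on the score (added for clarity): for an edge (u, v) with learned mask
# weight w(u, v), the cost handed to the shortest-path search is
#     log(in_degree(v)) - log(w(u, v)),
# so a path's total cost is the sum of these terms along its edges. Minimizing
# it prefers paths through high-weight edges and low-degree nodes, which is how
# the learned mask is turned into concise, informative path explanations.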

def bidirectional_dijkstra(g, src_nid, tgt_nid, weight=None, ignore_nodes=None, ignore_edges=None):
    """Dijkstra's algorithm for shortest paths using bidirectional search.
    
    Adapted from NetworkX _bidirectional_dijkstra
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html
    
    Parameters
    ----------
    g : dgl graph

    src_nid : int
        source node id

    tgt_nid : int
        target node id

    weight: callable function, optional 
       Takes in two node ids and return the edge weight. 

    ignore_nodes : container of nodes
       nodes to ignore, optional

    ignore_edges : container of edges
       edges to ignore, optional

    Returns
    -------
    (length, path) : tuple
        Shortest path length and the shortest path as a list of node ids.

    """
    if src_nid == tgt_nid:
        return (0, [src_nid])

    src, tgt = g.edges()
    Gpred = lambda i: src[tgt == i].tolist()
    Gsucc = lambda i: tgt[src == i].tolist()
    
    if ignore_nodes:
        def filter_iter(nodes):
            def iterate(v):
                for w in nodes(v):
                    if w not in ignore_nodes:
                        yield w

            return iterate

        Gpred = filter_iter(Gpred)
        Gsucc = filter_iter(Gsucc)
    
    if ignore_edges:
        def filter_pred_iter(pred_iter):
            def iterate(v):
                for w in pred_iter(v):
                    if (w, v) not in ignore_edges:
                        yield w

            return iterate

        def filter_succ_iter(succ_iter):
            def iterate(v):
                for w in succ_iter(v):
                    if (v, w) not in ignore_edges:
                        yield w

            return iterate

        Gpred = filter_pred_iter(Gpred)
        Gsucc = filter_succ_iter(Gsucc)

    push = heappush
    pop = heappop
    # Init:   Forward             Backward
    dists = [{}, {}]  # dictionary of final distances
    paths = [{src_nid: [src_nid]}, {tgt_nid: [tgt_nid]}]  # dictionary of paths
    fringe = [[], []]  # heap of (distance, node) tuples for
    # extracting next node to expand
    seen = [{src_nid: 0}, {tgt_nid: 0}]  # dictionary of distances to
    # nodes seen
    c = count()
    # initialize fringe heap
    push(fringe[0], (0, next(c), src_nid))
    push(fringe[1], (0, next(c), tgt_nid))
    # neighs for extracting correct neighbor information
    neighs = [Gsucc, Gpred]
    # variables to hold shortest discovered path
    finaldist = 1e30000  # effectively infinity, as in the NetworkX source
    finalpath = []
    dir = 1
    if not weight:
        weight = lambda u, v: 1
            
    while fringe[0] and fringe[1]:
        # choose direction
        # dir == 0 is forward direction and dir == 1 is back
        dir = 1 - dir
        # extract closest to expand
        (dist, _, v) = pop(fringe[dir])
        if v in dists[dir]:
            # Shortest path to v has already been found
            continue
        # update distance
        dists[dir][v] = dist  # equal to seen[dir][v]
        if v in dists[1 - dir]:
            # if we have scanned v in both directions we are done
            # we have now discovered the shortest path
            return (finaldist, finalpath)

        for w in neighs[dir](v):
            if dir == 0:  # forward
                minweight = weight(v, w)
                vwLength = dists[dir][v] + minweight
            else:  # back, must remember to change v,w->w,v
                minweight = weight(w, v)
                vwLength = dists[dir][v] + minweight

            if w in dists[dir]:
                if vwLength < dists[dir][w]:
                    raise ValueError("Contradictory paths found: negative weights?")
            elif w not in seen[dir] or vwLength < seen[dir][w]:
                # relaxing
                seen[dir][w] = vwLength
                push(fringe[dir], (vwLength, next(c), w))
                paths[dir][w] = paths[dir][v] + [w]
                if w in seen[0] and w in seen[1]:
                    # see if this path is better than the already
                    # discovered shortest path
                    totaldist = seen[0][w] + seen[1][w]
                    if finalpath == [] or finaldist > totaldist:
                        finaldist = totaldist
                        revpath = paths[1][w][:]
                        revpath.reverse()
                        finalpath = paths[0][w] + revpath[1:]
    raise ValueError("No paths found")


class PathBuffer:
    """For shortest paths finding
    
    Adapted from NetworkX shortest_simple_paths
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html

    """
    def __init__(self):
        self.paths = set()
        self.sortedpaths = list()
        self.counter = count()

    def __len__(self):
        return len(self.sortedpaths)

    def push(self, cost, path):
        hashable_path = tuple(path)
        if hashable_path not in self.paths:
            heappush(self.sortedpaths, (cost, next(self.counter), path))
            self.paths.add(hashable_path)

    def pop(self):
        (cost, num, path) = heappop(self.sortedpaths)
        hashable_path = tuple(path)
        self.paths.remove(hashable_path)
        return path
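# Illustrative usage (a sketch, not part of the original file): the buffer
# orders candidate paths by cost and drops duplicates.
#
#   buf = PathBuffer()
#   buf.push(2.0, [0, 3, 2])
#   buf.push(1.0, [0, 1, 2])
#   buf.push(1.0, [0, 1, 2])  # duplicate, ignored
#   buf.pop()                 # -> [0, 1, 2], the cheapest path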
    
def k_shortest_paths_generator(g, 
                               src_nid, 
                               tgt_nid, 
                               weight=None, 
                               k=5, 
                               ignore_nodes_init=None,
                               ignore_edges_init=None):
    """Generate at most `k` simple paths in the graph g from src_nid to tgt_nid,
       each with maximum lenghth `max_length`, return starting from the shortest ones. 
       If a weighted shortest path search is to be used, no negative weights are allowed.

    Adapted from NetworkX shortest_simple_paths
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html

    Parameters
    ----------
    g : dgl graph

    src_nid : int
        source node id

    tgt_nid : int
        target node id

    weight: callable function, optional 
       Takes in two node ids and return the edge weight. 

    k: int
       number of paths
    
    ignore_nodes_init : set of nodes
       nodes to ignore, optional

    ignore_edges_init : set of edges
       edges to ignore, optional

    Returns
    -------
    path_generator: generator
       A generator that yields paths in order from shortest to longest.
       Each path is a list of node ids.

    """
    if not weight:
        weight = lambda u, v: 1

    def length_func(path):
        return sum(weight(u, v) for (u, v) in zip(path, path[1:]))

    listA = list()
    listB = PathBuffer()
    prev_path = None
    while not prev_path or len(listA) < k:
        if not prev_path:
            length, path = bidirectional_dijkstra(g, src_nid, tgt_nid, weight, ignore_nodes_init, ignore_edges_init)
            listB.push(length, path)
        else:
            ignore_nodes = set(ignore_nodes_init) if ignore_nodes_init else set()
            ignore_edges = set(ignore_edges_init) if ignore_edges_init else set()
            for i in range(1, len(prev_path)):
                root = prev_path[:i]
                root_length = length_func(root)
                for path in listA:
                    if path[:i] == root:
                        ignore_edges.add((path[i - 1], path[i]))
                try:
                    length, spur = bidirectional_dijkstra(g,
                                                          root[-1],
                                                          tgt_nid,
                                                          ignore_nodes=ignore_nodes,
                                                          ignore_edges=ignore_edges,
                                                          weight=weight)
                    path = root[:-1] + spur
                    listB.push(root_length + length, path)
                except ValueError:
                    pass
                ignore_nodes.add(root[-1])
        
        if listB:
            path = listB.pop()
            yield path
            listA.append(path)
            prev_path = path
        else:
            break

def k_shortest_paths_with_max_length(g, 
                                     src_nid, 
                                     tgt_nid, 
                                     weight=None, 
                                     k=5, 
                                     max_length=None,
                                     ignore_nodes=None,
                                     ignore_edges=None):
    
    """Generate at most `k` simple paths in the graph g from src_nid to tgt_nid,
       each with maximum lenghth `max_length`, return starting from the shortest ones. 
       If a weighted shortest path search is to be used, no negative weights are allowed.
   
    Parameters
    ----------
       See function `k_shortest_paths_generator`
   
    Return
    -------
    paths: list of lists
       Each list is a path containing node ids
    """
    path_generator = k_shortest_paths_generator(g, 
                                                src_nid, 
                                                tgt_nid, 
                                                weight=weight,
                                                k=k, 
                                                ignore_nodes_init=ignore_nodes,
                                                ignore_edges_init=ignore_edges)
    
    try:
        if max_length:
            paths = [path for path in path_generator if len(path) <= max_length + 1]
        else:
            paths = list(path_generator)

    except ValueError:
        paths = [[]]

    return paths
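# Illustrative usage (a sketch, not part of the original file; assumes mask
# weights are stored in g.edata['eweight']): top-3 paths of length <= 4 under
# the degree-penalized score.
#
#   neg_score = get_neg_path_score_func(g, 'eweight', exclude_node=[src_nid, tgt_nid])
#   paths = k_shortest_paths_with_max_length(g, src_nid, tgt_nid,
#                                            weight=neg_score, k=3, max_length=4)
#   # paths is a list of node-id lists, best-scoring first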

'''
Evaluation utils
'''
def get_comp_g_edge_labels(comp_g, edge_labels):
    """Turn `edge_labels` with node ids in the original graph to
       `comp_g_edge_labels` with node ids in the computation graph.
       For easier evaluation.

    Parameters
    ----------
    comp_g : heterogeneous dgl graph
        computation graph, whose .ndata stores the key dgl.NID
    
    edge_labels : dict
        key=edge type, value=(source node ids, target node ids)
   
    Return
    -------
    comp_g_edge_labels: dict
        key=edge type, value=a tensor of labels, each label is in {0, 1}
    """
    ntype_to_tensor_nids_to_comp_g_nids = {}
    ntypes_to_comp_g_max_nids = {}
    ntypes_to_nids = comp_g.ndata[dgl.NID]
    for ntype in ntypes_to_nids.keys():
        nids = ntypes_to_nids[ntype]
        if nids.numel() > 0:
            max_nid = nids.max().item()
        else: 
            max_nid = -1

        ntypes_to_comp_g_max_nids[ntype] = max_nid

        nids_to_comp_g_nids = torch.zeros(max_nid + 1).long() - 1
        # The i-th entry will be the nid in comp_g for the i-th node in g
        nids_to_comp_g_nids[nids] = torch.arange(nids.shape[0])
        ntype_to_tensor_nids_to_comp_g_nids[ntype] = nids_to_comp_g_nids


    comp_g_edge_labels = {}
    for can_etype in edge_labels:
        start_ntype, etype, end_ntype = can_etype
        start_nids, end_nids = edge_labels[can_etype]
        start_comp_g_max_nid, end_comp_g_max_nid = ntypes_to_comp_g_max_nids[start_ntype], ntypes_to_comp_g_max_nids[end_ntype]

        # For edges in label but not in comp_g, exclude them
        start_included_nid_mask = start_nids <= start_comp_g_max_nid
        end_included_nid_mask = end_nids <= end_comp_g_max_nid
        comp_g_included_nid_mask = end_included_nid_mask & start_included_nid_mask

        start_nids = start_nids[comp_g_included_nid_mask]
        end_nids = end_nids[comp_g_included_nid_mask]

        comp_g_start_nids = ntype_to_tensor_nids_to_comp_g_nids[start_ntype][start_nids]
        comp_g_end_nids = ntype_to_tensor_nids_to_comp_g_nids[end_ntype][end_nids]
        comp_g_eids = comp_g.edge_ids(comp_g_start_nids.tolist(), comp_g_end_nids.tolist(), etype=etype)

        num_edges = comp_g.num_edges(etype=can_etype)
        comp_g_eid_mask = torch.zeros(num_edges)
        comp_g_eid_mask[comp_g_eids] = 1

        comp_g_edge_labels[can_etype] = comp_g_eid_mask

    return comp_g_edge_labels    

def get_comp_g_path_labels(comp_g, path_labels):
    """Turn `path_labels` with node ids in the original graph
       `comp_g_path_labels` with node ids in the computation graph
       For easier evaluation.

    Parameters
    ----------
    comp_g : heterogeneous dgl graph
        computation graph, whose .ndata stores the key dgl.NID
    
    path_labels : list of lists
        Each list is a path, i.e., a list of triples of
        (canonical edge type, source node id, target node id)
   
    Returns
    -------
    comp_g_path_labels: list of lists
        Each list is a path, i.e., a list of tuples of (canonical edge type, edge id)
    """
    ntype_to_tensor_nids_to_comp_g_nids = {}
    ntypes_to_comp_g_max_nids = {}
    ntypes_to_nids = comp_g.ndata[dgl.NID]
    for ntype in ntypes_to_nids.keys():
        nids = ntypes_to_nids[ntype]
        if nids.numel() > 0:
            max_nid = nids.max().item()
        else: 
            max_nid = -1

        ntypes_to_comp_g_max_nids[ntype] = max_nid

        nids_to_comp_g_nids = torch.zeros(max_nid + 1).long() - 1
        # The i-th entry will be the nid in comp_g for the i-th node in g
        nids_to_comp_g_nids[nids] = torch.arange(nids.shape[0])
        ntype_to_tensor_nids_to_comp_g_nids[ntype] = nids_to_comp_g_nids

    comp_g_path_labels = []
    for path in path_labels:
        comp_g_path = []
        for can_etype, start_nid, end_nid in path:
            start_ntype, etype, end_ntype = can_etype

            comp_g_start_nid = ntype_to_tensor_nids_to_comp_g_nids[start_ntype][start_nid].item()
            comp_g_end_nid = ntype_to_tensor_nids_to_comp_g_nids[end_ntype][end_nid].item()

            comp_g_eid = comp_g.edge_ids(comp_g_start_nid, comp_g_end_nid, etype=can_etype)
            comp_g_path += [(can_etype, comp_g_eid)]
        comp_g_path_labels += [comp_g_path]
    return comp_g_path_labels

def eval_edge_mask_auc(edge_mask_dict, edge_labels):
    '''
    Evaluate the AUC of an edge mask
    
    Parameters
    ----------
    edge_mask_dict: dict
        key=edge type, value=a tensor of scores, each score is in (-inf, inf)

    edge_labels: dict
        key=edge type, value=a tensor of labels, each label is in {0, 1}

    Returns
    ----------
    ROC-AUC score : float
    '''
    
    y_true = []
    y_score = []
    for can_etype in edge_labels:
        y_true += [edge_labels[can_etype]]
        y_score += [edge_mask_dict[can_etype].detach().sigmoid()]

    y_true = torch.cat(y_true)
    y_score = torch.cat(y_score)
    
    return roc_auc_score(y_true, y_score) 
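# Illustrative usage (a sketch, not part of the original file): AUC of a soft
# mask against binary edge labels for one canonical edge type.
#
#   et = ('user', 'clicks', 'item')
#   mask = {et: torch.tensor([2.0, -1.0, 0.5])}   # pre-sigmoid scores
#   labels = {et: torch.tensor([1.0, 0.0, 1.0])}
#   eval_edge_mask_auc(mask, labels)              # -> 1.0 (perfect ranking)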

def eval_edge_mask_topk_path_hit(edge_mask_dict, path_labels, topks=[10]):
    '''
    Evaluate the path hit rate of the top k edges in an edge mask
    
    Parameters
    ----------
    edge_mask_dict: dict
        key=edge type, value=a tensor of scores, each score is in (-inf, inf)

    path_labels: list of lists
        Each list is a path, i.e., a list of tuples of (canonical edge type, edge id)

    topks: iterable
        An iterable of the top `k` values. Each `k` determines how many edges to select 
        from the top values of the mask.

    Returns
    ----------
    topk_to_path_hit: dict
        Mapping each top `k` to a list containing the path hit indicator (1 or 0)
    '''
    cat_edge_mask = torch.cat([v for v in edge_mask_dict.values()])
    M = len(cat_edge_mask)
    topks = {k: min(k, M) for k in topks}

    topk_to_path_hit = defaultdict(list)
    for r, k in topks.items():
        threshold = cat_edge_mask.topk(k)[0][-1].item()
        hard_edge_mask_dict = {}
        for etype in edge_mask_dict:
            hard_edge_mask_dict[etype] = edge_mask_dict[etype] >= threshold

        hit = eval_hard_edge_mask_path_hit(hard_edge_mask_dict, path_labels)
        topk_to_path_hit[r] += [hit]
    return topk_to_path_hit
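# Illustrative usage (a sketch, not part of the original file): do the top-2
# mask edges cover a labeled path made of edge ids 0 and 1?
#
#   et = ('user', 'clicks', 'item')
#   mask = {et: torch.tensor([3.0, 2.0, -1.0])}
#   path_labels = [[(et, 0), (et, 1)]]
#   eval_edge_mask_topk_path_hit(mask, path_labels, topks=[2])
#   # -> {2: [1]}: both path edges clear the top-2 threshold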

def eval_hard_edge_mask_path_hit(hard_edge_mask_dict, path_labels):
    '''
    Evaluate the path hit of a hard edge mask
    
    Parameters
    ----------
    hard_edge_mask_dict: dict
        key=edge type, value=a tensor of labels, each label is in {True, False}

    path_labels: list of lists
        Each list is a path, i.e., a list of tuples of (canonical edge type, edge id)

    Returns
    ----------
    hit_path: int
        1 or 0
    '''
    for path in path_labels:
        hit_path = 1
        for can_etype, eid in path:
            if not hard_edge_mask_dict[can_etype][eid]:
                hit_path = 0
                break
        if hit_path:
            return 1
    return 0


def eval_path_explanation_edges_path_hit(path_explanation_edges, path_labels):
    '''
    Evaluate the path hit of a list of path-explanation edges
    
    Parameters
    ----------
    path_explanation_edges : list
        Edges on the path explanation; each edge is a triple of
        (canonical edge type, source node id, target node id)
    
    path_labels : list of lists
        Each list is a path, i.e., a list of triples of
        (canonical edge type, source node id, target node id)

    Returns
    ----------
    hit_path: int
        1 or 0
    '''
    for path in path_labels:
        hit_path = 1
        for edge in path:
            if edge not in path_explanation_edges:
                hit_path = 0
                break
        if hit_path:
            return 1
    return 0


'''
Plotting utils
'''
def plot_hetero_graph(ghetero,
                      ntypes_to_nshapes=None,
                      ntypes_to_ncolors=None,
                      ntypes_to_nlayers=None,
                      layout='multipartite',
                      layout_seed=0,
                      node_size=1000,
                      edge_kwargs={},
                      selected_node_dict=None,
                      selected_node_color='red',
                      selected_edge_dict=None,
                      selected_edge_kwargs={},
                      label='nid',
                      etype_label=True,
                      label_offset=False,
                      title=None,
                      legend=False,
                      figsize=(10, 10),
                      fig_name=None,
                      fig_format='png',
                      is_show=True):
        '''
        Parameters
        ----------
        ghetero: a DGL heterogeneous graph with ndata `order`

        ntypes_to_nshapes : Dict
            mapping node types to node shapes
        
        ntypes_to_ncolors : Dict
            mapping node types to node colors

        ntypes_to_nlayers : Dict 
            mapping node types to layer order in the multipartite layout. 

        label: String
            one of ['none', 'nid'] or a node feature stored in ndata of ghetero

        Returns
        ----------
        nx_graph : networkx graph
        
        '''
        if ntypes_to_nshapes is None:
            default_node_shape = 'o'
        if ntypes_to_ncolors is None:
            default_node_color = 'cyan'
        if selected_node_dict is not None:
            selected_node_dict = {ntype: list(selected_node_dict[ntype]) for ntype in selected_node_dict}

        # Convert DGL graph to networkx graph
        ghomo = dgl.to_homogeneous(ghetero)
        edges = torch.cat([t.unsqueeze(1) for t in ghomo.edges()], dim=1)
        edge_list = [(n_frm, n_to) for (n_frm, n_to) in edges.tolist()]
        nx_graph = dgl.to_networkx(ghomo, node_attrs=[dgl.NTYPE])
            
        # Use different layout
        if layout == 'spring':
            pos = nx.spring_layout(nx_graph, seed=layout_seed)
        elif layout == 'kk':
            pos = nx.kamada_kawai_layout(nx_graph)
        elif layout == 'multipartite':
            if ntypes_to_nlayers is not None:
                ntype_ids_to_nlayers = {ghetero.get_ntype_id(ntype): ntypes_to_nlayers[ntype] for ntype in ghetero.ntypes}
            else:
                ntype_ids_to_nlayers = {ghetero.get_ntype_id(ntype): i for i, ntype in enumerate(ghetero.ntypes)}
                
            for i in nx_graph.nodes():
                ntype_id = nx_graph.nodes()[i][dgl.NTYPE].item()
                nx_graph.nodes()[i][dgl.NTYPE] = ntype_ids_to_nlayers[ntype_id]

            pos = nx.multipartite_layout(nx_graph, subset_key=dgl.NTYPE, scale=1)
        else:
            raise ValueError('Unknown layout')

        # Start drawing
        plt.figure(figsize=figsize)
        ax = plt.gca()
 
        # Draw nodes for each ntype
        for ntype in ghetero.ntypes:
            ntype_ids = ghomo.ndata[dgl.NTYPE]
            hetero_nids = ghomo.ndata[dgl.NID] # nid in the original hetero graph
            
            node_shape = ntypes_to_nshapes[ntype] if ntypes_to_nshapes else default_node_shape
            node_color = ntypes_to_ncolors[ntype] if ntypes_to_ncolors else default_node_color

            # For the current node type, get the node type id and node ids
            curr_ntype_id = ghetero.get_ntype_id(ntype)
            curr_nids_mask = ntype_ids == curr_ntype_id
            curr_nids = curr_nids_mask.nonzero().view(-1).tolist()

            # For the current node type, get node ids and prediction node id in the original hetero graph
            curr_hetero_nids = hetero_nids[curr_nids_mask]
            
            if selected_node_dict is not None:
                curr_hetero_selected_nid = selected_node_dict.get(ntype)
                if curr_hetero_selected_nid is not None:
                    curr_node_color = []
                    for hetero_nid in curr_hetero_nids:
                        curr_node_color += [selected_node_color if hetero_nid in curr_hetero_selected_nid else node_color]
                    node_color = curr_node_color

            nx.draw_networkx_nodes(nx_graph, 
                                   pos, 
                                   curr_nids, 
                                   node_shape=node_shape,
                                   node_color=node_color,
                                   node_size=node_size,
                                   ax=ax)
            
        # Draw edges
        nx.draw_networkx_edges(nx_graph, pos, edge_list, **edge_kwargs, ax=ax)
        
        if selected_edge_dict is not None:
            ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)
            homo_selected_edge_list = []
            for etype in selected_edge_dict:
                src_ntype, _, tgt_ntype = ghetero.to_canonical_etype(etype)
                src_nids, tgt_nids = selected_edge_dict[etype]
                for src_nid, tgt_nid in zip(src_nids.tolist(), tgt_nids.tolist()):
                    homo_src_nid = ntype_hetero_nids_to_homo_nids[(src_ntype, src_nid)]
                    homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(tgt_ntype, tgt_nid)]
                    homo_selected_edge_list += [(homo_src_nid, homo_tgt_nid)]
        
            nx.draw_networkx_edges(nx_graph, pos, homo_selected_edge_list, **selected_edge_kwargs, ax=ax)
            
            
        # Start labelling nodes
        if label == 'none':
            pass
        elif label == 'nid':
            homo_nids_to_hetero_nids = get_homo_nids_to_hetero_nids(ghetero)
            nx.draw_networkx_labels(nx_graph, pos, labels=homo_nids_to_hetero_nids)
        else:
            # Set extra space to avoid label outside of the box
            x_values, y_values = zip(*pos.values())
            x_max = max(x_values)
            x_min = min(x_values)
            x_margin = (x_max - x_min) * 0.12
            ax.set_xlim(x_min - x_margin, x_max + x_margin)


            if ghetero.ndata.get(label):
                homo_nids_to_hetero_ndata_feat = get_homo_nids_to_hetero_ntype_data_feat(ghetero, label)
                if label_offset:
                    offset = 0.8 / figsize[1]
                    label_pos = {nid : [p[0], p[1] - offset] for nid, p in pos.items()} 
                else:
                    label_pos = pos

                nx.draw_networkx_labels(nx_graph, 
                                        label_pos, 
                                        font_size=14, 
                                        font_weight='bold', 
                                        labels=homo_nids_to_hetero_ndata_feat,
                                        horizontalalignment='center',
                                        verticalalignment='center',
                                        ax=ax)

            else:
                raise ValueError('Unrecognized label')
            
        # Start labelling edges with etype
        if etype_label:
            if ghetero.ndata.get(label):
                homo_nid_pairs_to_etypes = get_homo_nid_pairs_to_etypes(ghetero)
                nx.draw_networkx_edge_labels(nx_graph, 
                                             pos, 
                                             font_size=13, 
                                             font_weight='bold', 
                                             edge_labels=homo_nid_pairs_to_etypes,
                                             horizontalalignment='center',
                                             verticalalignment='center',
                                             ax=ax)
            
        if legend:
            plt.legend(ghetero.ntypes, fontsize=15, prop={'size': figsize[0]*2.5}, bbox_to_anchor = (1.15, 0.7)) 

        ax.axis('off')
        if title is not None:
            plt.title(textwrap.fill(title, width=60))
        if fig_name is not None:
            plt.savefig(fig_name, format=fig_format, bbox_inches='tight')
        if is_show:
            plt.show()
        if fig_name is not None:
            plt.close()
            
        return nx_graph
  
def get_homo_nids_to_hetero_ntype_data_feat(ghetero, feat=dgl.NID):
    '''
    Plotting helper function
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    homo_nids = range(ghomo.num_nodes())
    hetero_ndata_feat = []
    for ntype in ghetero.ntypes:
        hetero_ndata_feat += [f'{ntype[0]}{val}' for val in ghetero.ndata[feat][ntype].tolist()]

    homo_nids_to_hetero_ndata_feat = dict(zip(homo_nids, hetero_ndata_feat))
    return homo_nids_to_hetero_ndata_feat

def get_homo_nid_pairs_to_etypes(ghetero):
    '''
    Plotting helper function
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    etypes = ghetero.etypes
    etype_list = [etypes[etype_id] for etype_id in ghomo.edata[dgl.ETYPE]]
    u, v = ghomo.edges()
    homo_nid_pairs_to_etypes = dict(zip(zip(u.tolist(), v.tolist()), etype_list))
    return homo_nid_pairs_to_etypes

Code execution flow: first train the link prediction model, then build the explanation model on top of it.
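A rough sketch of that flow (the explainer class name and its arguments below are illustrative placeholders, not the repo's exact API; check the PaGE-Link repository for the real entry points):

import torch

# Step 1: train the link prediction model with the training script shown
# above; with --save_model it writes f"{args.dataset_name}_model.pth".

# Step 2: reload the frozen model and explain one predicted (src, tgt) pair.
model.load_state_dict(torch.load(f'{args.dataset_name}_model.pth'))
model.eval()

# Hypothetical explainer call, for illustration only: a PaGE-Link-style
# explainer learns an edge mask on the k-hop subgraph around the node pair
# and extracts the top-scoring paths from it.
# explainer = PaGELink(model, lr=0.01, num_epochs=20)
# paths = explainer.explain(src_nid, tgt_nid, ghetero, k=2, num_paths=2)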

The experiments in the paper predict links between nodes of different types; if you want to predict links between nodes of the same type, the code needs some minor tweaks. Also, all edges in the paper's graphs are bidirectional; if you introduce unidirectional links, the data processing needs small adjustments as well. Good luck, and keep learning!
