《PaGE-Link: Path-based Graph Neural Network Explanation for Heterogeneous Link Prediction》
1. 论文概要
作者提出了一种用于异构图链路预测(边预测)的基于路径的GNN解释方法(PaGE-Link)。通俗地讲,就是解释模型为什么预测某条边存在,让预测结果令人信服(神经网络常被视为黑盒,很多时候大家不认可它的结果,原因就是缺乏可解释性)。
该方法生成预测连接的解释, 具有模型可扩展性,并且可以用于异构图处理。PaGE-Link可以作为连接节点对的路径生成的解释,首先预测两个节点之间的连接,其次生成相关解释。
个人感觉就是把重要的边连接的依据特征给筛选出来,生成一堆非直接路径作为两个节点的链接依据。
图中做了相关举例:针对user1和item1的链接,理论依据的路径1是u1 ->i2->a1->i1,路径2是u1 ->i3->u2->i1。其实还有很多依据,只不过模型通过学习给出了两个最可能的依据。这个数量是自己可以设定的。
鉴于GNN的解释对LP的重要性和挑战,作者将其表述为一个事后和实例级的解释问题,并以连接源节点和目标节点的重要路径的形式对其进行解释。
相关工作部分,可以自行查阅该方面的其它研究。
提出的方法主要有两部分:k-core剪枝,用于减小模型的搜索规模、降低时间复杂度;掩码学习,用于找出权重较高的边,作为路径选择的特征依据。
模型解释的最终目的是提高模型透明度,帮助人类决策。因此,人工评估是评估解释器有效性的最佳方式,这在之前的工作中一直是标准的评估方法。我们从AugCitation的测试集中随机选取100个预测链接进行人工评估,并使用GNNExp-Link、PGExp-Link和PaGE-Link为每个链接生成解释。我们设计了一个由单项选择题组成的调查。在每个问题中,我们用图结构和节点/边类型信息向受访者展示预测的链接和这三种解释,形式类似于图5,但不显示方法名称。该调查被发送给研究生、博士后、工程师、研究科学家和教授,其中既有具备GNN背景知识的人,也有不具备的人。我们要求受访者“请选择最佳解释:为什么模型预测这个作者会喜欢推荐的论文?”。每个问题至少收集了来自不同人的三个答案。总共收集了340个评价,其中78.79%的评价认为PaGE-Link的解释是最好的。
2. 开源代码
data_processing.py
数据处理过程
import dgl
import torch
import scipy.sparse as sp
import numpy as np
from utils import eids_split, remove_all_edges_of_etype, get_num_nodes_dict
def _negative_candidates(pos_u, pos_v, shape):
    '''
    Return (rows, cols) of every (source, target) pair that has no positive edge.

    The positive pairs are written into a dense `shape`-sized matrix, so memory
    grows with num_src * num_tgt.
    '''
    adj = sp.coo_matrix((np.ones(len(pos_u)), (pos_u.numpy(), pos_v.numpy())), shape=shape)
    # Duplicate (u, v) entries are summed when a COO matrix is densified. Test for
    # exactly zero: `1 - adj != 0` would turn a pair that occurs twice (2 -> -1)
    # into a "negative" edge even though it is connected.
    return np.where(np.asarray(adj.todense()) == 0)


def process_data(g,
                 val_ratio,
                 test_ratio,
                 src_ntype='author',
                 tgt_ntype='paper',
                 pred_etype='likes',
                 neg='pred_etype_neg'):
    '''
    Split the `pred_etype` edges of `g` into train/valid/test positive graphs,
    sample matching negative graphs, and build the message passing graph.

    Parameters
    ----------
    g : dgl graph
    val_ratio : float
    test_ratio : float
    src_ntype: string
        source node type
    tgt_ntype: string
        target node type
    pred_etype: string
        edge type to predict; its edges are the positive examples
    neg: string
        One of ['pred_etype_neg', 'src_tgt_neg'], different negative sampling modes.
        'pred_etype_neg': pairs without a `pred_etype` edge are negatives.
        'src_tgt_neg': pairs without an edge of any type from `src_ntype` to
        `tgt_ntype` are negatives.

    Returns
    ----------
    mp_g:
        graph for message passing.
    graphs containing positive edges and negative edges for train, valid, and test

    Raises
    ----------
    ValueError
        If `neg` is not one of the two supported modes.
    '''
    u, v = g.edges(etype=pred_etype)
    src_N = g.num_nodes(src_ntype)
    tgt_N = g.num_nodes(tgt_ntype)
    M = u.shape[0]  # number of directed edges
    eids = torch.arange(M)
    train_pos_eids, val_pos_eids, test_pos_eids = eids_split(eids, val_ratio, test_ratio)
    train_pos_u, train_pos_v = u[train_pos_eids], v[train_pos_eids]
    val_pos_u, val_pos_v = u[val_pos_eids], v[val_pos_eids]
    test_pos_u, test_pos_v = u[test_pos_eids], v[test_pos_eids]

    if neg == 'pred_etype_neg':
        # Edges not in pred_etype as negative edges
        neg_u, neg_v = _negative_candidates(u, v, (src_N, tgt_N))
    elif neg == 'src_tgt_neg':
        # Edges not connecting src and tgt as negative edges
        # Collect all edges between the src and tgt, over every edge type
        src_tgt_indices = []
        for etype in g.canonical_etypes:
            if etype[0] == src_ntype and etype[2] == tgt_ntype:
                adj = g.adj(etype=etype)
                src_tgt_indices.append(adj.coalesce().indices())
        src_tgt_u, src_tgt_v = torch.cat(src_tgt_indices, dim=1)
        # The same pair may be linked by several edge types, which is exactly
        # the duplicate case the helper guards against.
        neg_u, neg_v = _negative_candidates(src_tgt_u, src_tgt_v, (src_N, tgt_N))
    else:
        raise ValueError('Unknow negative argument')

    # Sample at most as many negatives as there are positive edges
    neg_eids = np.random.choice(neg_u.shape[0], min(neg_u.shape[0], M), replace=False)
    train_neg_eids, val_neg_eids, test_neg_eids = eids_split(torch.from_numpy(neg_eids), val_ratio, test_ratio)

    # np.take avoids losing a dimension when a split holds a single id
    train_neg_u, train_neg_v = np.take(neg_u, train_neg_eids), np.take(neg_v, train_neg_eids)
    val_neg_u, val_neg_v = np.take(neg_u, val_neg_eids), np.take(neg_v, val_neg_eids)
    test_neg_u, test_neg_v = np.take(neg_u, test_neg_eids), np.take(neg_v, test_neg_eids)

    # Construct graphs
    pred_can_etype = (src_ntype, pred_etype, tgt_ntype)
    num_nodes_dict = get_num_nodes_dict(g)

    train_pos_g = dgl.heterograph({pred_can_etype: (train_pos_u, train_pos_v)}, num_nodes_dict)
    train_neg_g = dgl.heterograph({pred_can_etype: (train_neg_u, train_neg_v)}, num_nodes_dict)
    val_pos_g = dgl.heterograph({pred_can_etype: (val_pos_u, val_pos_v)}, num_nodes_dict)
    val_neg_g = dgl.heterograph({pred_can_etype: (val_neg_u, val_neg_v)}, num_nodes_dict)
    test_pos_g = dgl.heterograph({pred_can_etype: (test_pos_u, test_pos_v)}, num_nodes_dict)
    test_neg_g = dgl.heterograph({pred_can_etype: (test_neg_u, test_neg_v)}, num_nodes_dict)

    mp_g = remove_all_edges_of_etype(g, pred_etype)  # Remove pred_etype edges but keep nodes
    return mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g
def load_dataset(dataset_dir, dataset_name, val_ratio, test_ratio):
    '''
    Load a saved graph with its explanation labels and split it for link prediction.

    Parameters
    ----------
    dataset_dir : string
        dataset directory
    dataset_name : string
        Must contain 'synthetic' or 'citation'; this selects the node types
        of the predicted link.
    val_ratio : float
    test_ratio : float

    Returns:
    ----------
    g: dgl graph
        The original graph
    processed_g: tuple of seven dgl graphs
        The outputs of the function `process_data`,
        which includes g for message passing, train, valid, and test
    pred_pair_to_edge_labels : dict
        key=((source node type, source node id), (target node type, target node id))
        value=dict, {cannonical edge type: (source node ids, target node ids)}
    pred_pair_to_path_labels : dict
        key=((source node type, source node id), (target node type, target node id))
        value=list of lists, each list contains (cannonical edge type, source node ids, target node ids)

    Raises
    ----------
    ValueError
        If `dataset_name` matches neither known dataset family.
    '''
    # Resolve the node types first so an unsupported name fails before any file is read.
    if 'synthetic' in dataset_name:
        src_ntype, tgt_ntype = 'user', 'item'
    elif 'citation' in dataset_name:
        src_ntype, tgt_ntype = 'author', 'paper'
    else:
        raise ValueError(
            f"Unsupported dataset_name '{dataset_name}': expected a name containing "
            "'synthetic' or 'citation'"
        )

    graph_saving_path = f'{dataset_dir}/{dataset_name}'
    graph_list, _ = dgl.load_graphs(graph_saving_path)
    # NOTE: torch.load unpickles these files; only load label files from a trusted source.
    pred_pair_to_edge_labels = torch.load(f'{graph_saving_path}_pred_pair_to_edge_labels')
    pred_pair_to_path_labels = torch.load(f'{graph_saving_path}_pred_pair_to_path_labels')
    g = graph_list[0]

    pred_etype = 'likes'
    neg = 'src_tgt_neg'
    processed_g = process_data(g, val_ratio, test_ratio, src_ntype, tgt_ntype, pred_etype, neg)
    return g, processed_g, pred_pair_to_edge_labels, pred_pair_to_path_labels
eval_explanations.py
import torch
import numpy as np
import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from tqdm.auto import tqdm
from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel
from utils import set_config_args, get_comp_g_edge_labels, get_comp_g_path_labels
from utils import hetero_src_tgt_khop_in_subgraph, eval_edge_mask_auc, eval_edge_mask_topk_path_hit
# Command line interface of the explanation evaluation script.
parser = argparse.ArgumentParser(description='Explain link predictor')

# Dataset args
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1)
parser.add_argument('--test_ratio', type=float, default=0.2)
parser.add_argument('--max_num_samples', type=int, default=-1,
                    help='maximum number of samples to explain, for fast testing. Use all if -1')

# GNN args
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)
parser.add_argument('--saved_model_dir', type=str, default='saved_models')
parser.add_argument('--saved_model_name', type=str, default='')

# Link predictor args
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument('--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
                    help='operation passed to dgl.EdgePredictor')

# Explanation args
parser.add_argument('--num_hops', type=int, default=2, help='computation graph number of hops')
parser.add_argument('--saved_explanation_dir', type=str, default='saved_explanations',
                    help='directory of saved explanations')
parser.add_argument('--eval_explainer_names', nargs='+', default=['pagelink'],
                    help='name of explainers to evaluate')
# The flag gates the top-k path hit rate evaluation below; the previous help
# text ("Whether to save the explanation") described a different option.
parser.add_argument('--eval_path_hit', default=False, action='store_true',
                    help='Whether to also evaluate the top-k path hit rate')
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()
# Optionally overwrite the CLI arguments with a saved configuration.
if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'train_eval')

# The node types of the predicted link follow from the dataset family.
if 'citation' in args.dataset_name:
    args.src_ntype, args.tgt_ntype = 'author', 'paper'
elif 'synthetic' in args.dataset_name:
    args.src_ntype, args.tgt_ntype = 'user', 'item'

# Only the 'cat' predictor receives extra constructor arguments.
pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1} if args.link_pred_op in ['cat'] else {}

g, processed_g, pred_pair_to_edge_labels, pred_pair_to_path_labels = load_dataset(
    args.dataset_dir,
    args.dataset_name,
    args.valid_ratio,
    args.test_ratio,
)
(mp_g,
 train_pos_g, train_neg_g,
 val_pos_g, val_neg_g,
 test_pos_g, test_neg_g) = processed_g

# Rebuild the link predictor and restore its trained weights on CPU.
encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)

if not args.saved_model_name:
    args.saved_model_name = f'{args.dataset_name}_model'

checkpoint_path = f'{args.saved_model_dir}/{args.saved_model_name}.pth'
state = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state)
# For every positive test edge that the model predicts as a link, cache its
# computation graph and the ground-truth labels expressed in that graph.
test_src_nids, test_tgt_nids = test_pos_g.edges()
comp_graphs = defaultdict(list)
comp_g_labels = defaultdict(list)

num_test_edges = test_src_nids.shape[0]
test_ids = range(num_test_edges)
if args.max_num_samples > 0:
    test_ids = test_ids[:args.max_num_samples]

for i in tqdm(test_ids):
    src_nid = test_src_nids[i]
    tgt_nid = test_tgt_nids[i]

    # Get the k-hop subgraph
    khop_result = hetero_src_tgt_khop_in_subgraph(args.src_ntype,
                                                  src_nid,
                                                  args.tgt_ntype,
                                                  tgt_nid,
                                                  mp_g,
                                                  args.num_hops)
    comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids = khop_result

    with torch.no_grad():
        logit = model(comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids)
        pred = logit.sigmoid().item() > 0.5

    # Only links the model predicts as existing are kept for evaluation.
    if not pred:
        continue

    src_tgt = ((args.src_ntype, int(src_nid)), (args.tgt_ntype, int(tgt_nid)))
    comp_graphs[src_tgt] = [comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids]

    # Get labels with subgraph nids and eids
    edge_labels = pred_pair_to_edge_labels[src_tgt]
    path_labels = pred_pair_to_path_labels[src_tgt]
    comp_g_edge_labels = get_comp_g_edge_labels(comp_g, edge_labels)
    comp_g_path_labels = get_comp_g_path_labels(comp_g, path_labels)
    comp_g_labels[src_tgt] = [comp_g_edge_labels, comp_g_path_labels]
# Load the saved edge masks of every explainer under evaluation.
# NOTE: pickle.load executes arbitrary code from the file; only load trusted explanation files.
explanation_masks = {}
explanation_dir = Path.cwd().joinpath(args.saved_explanation_dir)
for explainer in args.eval_explainer_names:
    saved_explanation_mask = f'{explainer}_{args.saved_model_name}_pred_edge_to_comp_g_edge_mask'
    saved_file = explanation_dir.joinpath(saved_explanation_mask)
    with open(saved_file, "rb") as f:
        explanation_masks[explainer] = pickle.load(f)

# Report the Mask-AUC of each explainer, averaged over all collected links.
print('Dataset:', args.dataset_name)
for explainer in args.eval_explainer_names:
    print(explainer)
    print('-'*30)
    pred_edge_to_comp_g_edge_mask = explanation_masks[explainer]

    mask_auc_list = []
    for src_tgt in comp_graphs:
        comp_g_edge_labels, comp_g_path_labels = comp_g_labels[src_tgt]
        comp_g_edge_mask_dict = pred_edge_to_comp_g_edge_mask[src_tgt]
        mask_auc_list.append(eval_edge_mask_auc(comp_g_edge_mask_dict, comp_g_edge_labels))

    avg_auc = np.mean(mask_auc_list)

    # Print
    np.set_printoptions(precision=4, suppress=True)
    print(f'Average Mask-AUC: {avg_auc : .4f}')
    print('-'*30, '\n')
# Optional second metric: the hit rate of ground-truth paths among the top-k masked edges.
if args.eval_path_hit:
    topks = [3, 5, 10, 20, 50, 100, 200]

    for explainer in args.eval_explainer_names:
        print(explainer)
        print('-'*30)
        pred_edge_to_comp_g_edge_mask = explanation_masks[explainer]

        # Collect, for each k, one hit value per explained link.
        explainer_to_topk_path_hit = defaultdict(list)
        for src_tgt in comp_graphs:
            comp_g_path_labels = comp_g_labels[src_tgt][1]
            comp_g_edge_mask_dict = pred_edge_to_comp_g_edge_mask[src_tgt]
            topk_to_path_hit = eval_edge_mask_topk_path_hit(comp_g_edge_mask_dict, comp_g_path_labels, topks)
            for topk, hit in topk_to_path_hit.items():
                explainer_to_topk_path_hit[topk].append(hit)

        # Take average
        explainer_to_topk_path_hit_rate = {
            topk: np.array(hits).mean(0)
            for topk, hits in explainer_to_topk_path_hit.items()
        }

        # Print
        np.set_printoptions(precision=4, suppress=True)
        for k, hr in explainer_to_topk_path_hit_rate.items():
            print(f'k: {k :3} | Path HR: {hr.item(): .4f}')

        print('-'*30, '\n')
explainer.py
import dgl
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
from collections import defaultdict
from utils import get_ntype_hetero_nids_to_homo_nids, get_homo_nids_to_ntype_hetero_nids, get_ntype_pairs_to_cannonical_etypes
from utils import hetero_src_tgt_khop_in_subgraph, get_neg_path_score_func, k_shortest_paths_with_max_length
def get_edge_mask_dict(ghetero):
    '''
    Build one randomly initialised, learnable edge mask per canonical edge type.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph.

    Return
    ----------
    edge_mask_dict : dictionary
        key=etype, value=torch.nn.Parameter with size number of etype edges
    '''
    device = ghetero.device
    relu_gain = torch.nn.init.calculate_gain('relu')

    masks = {}
    for etype in ghetero.canonical_etypes:
        n_edges = ghetero.num_edges(etype)
        # Node count of the subgraph induced by this single edge type
        n_nodes = ghetero.edge_type_subgraph([etype]).num_nodes()
        # The standard deviation of the initial values shrinks with the node count
        std = relu_gain * np.sqrt(2.0 / (2 * n_nodes))
        initial_values = torch.randn(n_edges, device=device) * std
        masks[etype] = torch.nn.Parameter(initial_values)
    return masks
def remove_edges_of_high_degree_nodes(ghomo, max_degree=10, always_preserve=None):
    '''
    For all the nodes with in-degree higher than `max_degree`,
    except nodes in `always_preserve`, remove their edges.

    Parameters
    ----------
    ghomo : dgl homogeneous graph
    max_degree : int
    always_preserve : iterable, optional
        These nodes won't be pruned. `None` (default) preserves nothing.

    Returns
    -------
    low_degree_ghomo : dgl homogeneous graph
        Pruned graph with edges of high degree nodes removed
    '''
    in_degrees = ghomo.in_degrees()
    high_degree_mask = in_degrees > max_degree

    # preserve nodes
    if always_preserve is not None and len(always_preserve) > 0:
        high_degree_mask[always_preserve] = False

    high_degree_nids = ghomo.nodes()[high_degree_mask]
    u, v = ghomo.edges()
    high_degree_edge_mask = torch.isin(u, high_degree_nids) | torch.isin(v, high_degree_nids)

    # `edges()` lists edges in edge-id order, so the mask positions are the edge ids.
    # Looking them up with `edge_ids(u, v)` would return only one edge per node pair
    # and leave parallel edges of a high degree node in the graph.
    high_degree_eids = high_degree_edge_mask.nonzero().view(-1).to(ghomo.idtype)
    low_degree_ghomo = dgl.remove_edges(ghomo, high_degree_eids)

    return low_degree_ghomo
def remove_edges_except_k_core_graph(ghomo, k, always_preserve=None):
    '''
    Find the `k`-core of `ghomo`.
    Only isolate the low degree nodes by removing theirs edges
    instead of removing the nodes, so node ids can be kept.

    Parameters
    ----------
    ghomo : dgl homogeneous graph
    k : int
    always_preserve : iterable, optional
        These nodes won't be pruned. `None` (default) preserves nothing.

    Returns
    -------
    k_core_ghomo : dgl homogeneous graph
        The k-core graph
    '''
    has_preserved = always_preserve is not None and len(always_preserve) > 0
    k_core_ghomo = ghomo

    # Peel repeatedly: removing edges lowers the degree of neighbours, which
    # may push them below `k` in the next round.
    while True:
        degrees = k_core_ghomo.in_degrees()
        # Nodes that still have edges but too few of them to stay in the k-core
        low_degree_mask = (degrees > 0) & (degrees < k)
        if has_preserved:
            low_degree_mask[always_preserve] = False
        if not low_degree_mask.any():
            break

        low_degree_nids = k_core_ghomo.nodes()[low_degree_mask]
        u, v = k_core_ghomo.edges()
        removed_edge_mask = torch.isin(u, low_degree_nids) | torch.isin(v, low_degree_nids)

        # `edges()` lists edges in edge-id order, so the mask positions are the edge ids.
        # `edge_ids(u, v)` would return only one edge per node pair and miss parallel edges.
        removed_eids = removed_edge_mask.nonzero().view(-1).to(k_core_ghomo.idtype)
        k_core_ghomo = dgl.remove_edges(k_core_ghomo, removed_eids)

    return k_core_ghomo
def get_eids_on_paths(paths, ghomo):
    '''
    Gather the edge id of every hop of every path.

    Note: this is the list version, so an edge shared by several paths is
    collected once per occurrence. A set version would count each edge once.

    Parameters
    ----------
    paths : list of lists
        Each list is a sequence of node ids; consecutive ids form one hop.
    ghomo : dgl homogeneous graph

    Returns
    -------
    eids : torch.LongTensor
        Edge ids on the paths, in traversal order. Each hop must match exactly
        one edge of `ghomo`, otherwise the `.item()` conversion raises.
    '''
    src, dst = ghomo.edges()
    collected = []
    for path in paths:
        for head, tail in zip(path[:-1], path[1:]):
            hop_mask = (src == head) & (dst == tail)
            collected.append(hop_mask.nonzero().item())
    return torch.LongTensor(collected)
def comp_g_paths_to_paths(comp_g, comp_g_paths):
    '''
    Rewrite paths expressed with computation graph node ids so that they use
    the node ids stored under `dgl.NID` in `comp_g.ndata`.

    Parameters
    ----------
    comp_g : dgl heterogeneous graph
    comp_g_paths : list of lists
        each list contains (cannonical edge type, source node id, target node id)

    Returns
    -------
    paths : list of lists
        Same structure as `comp_g_paths`, with mapped node ids as Python numbers.
    '''
    original_nids = comp_g.ndata[dgl.NID]

    def _map_hop(hop):
        can_etype, u, v = hop
        u_ntype, _, v_ntype = can_etype
        mapped_u = original_nids[u_ntype][u].item()
        mapped_v = original_nids[v_ntype][v].item()
        return (can_etype, mapped_u, mapped_v)

    return [[_map_hop(hop) for hop in comp_g_path] for comp_g_path in comp_g_paths]
class PaGELink(nn.Module):
    """Path-based GNN Explanation for Heterogeneous Link Prediction (PaGELink)

    Some methods are adapted from the DGL GNNExplainer implementation
    https://docs.dgl.ai/en/0.8.x/_modules/dgl/nn/pytorch/explain/gnnexplainer.html#GNNExplainer

    Parameters
    ----------
    model : nn.Module
        The GNN-based link prediction model to explain.

        * The required arguments of its forward function are source node id, target node id,
          graph, and feature ids. The feature ids are for selecting input node features.
        * It should also optionally take an eweight argument for edge weights
          and multiply the messages by the weights during message passing.
        * The output of its forward function is the logits in (-inf, inf) for the
          predicted link.
        * It must expose `src_ntype` and `tgt_ntype` attributes.
    lr : float, optional
        The learning rate to use, default to 0.001.
    num_epochs : int, optional
        The number of epochs to train, default to 100.
    alpha : float, optional
        A higher value will make edges on high-quality paths to have higher weights
    beta : float, optional
        A higher value will make edges off high-quality paths to have lower weights
    log : bool, optional
        If True, it will log the computation process, default to False.
    """
    def __init__(self,
                 model,
                 lr=0.001,
                 num_epochs=100,
                 alpha=1.0,
                 beta=1.0,
                 log=False):
        super(PaGELink, self).__init__()
        self.model = model
        self.src_ntype = model.src_ntype
        self.tgt_ntype = model.tgt_ntype

        self.lr = lr
        self.num_epochs = num_epochs
        self.alpha = alpha
        self.beta = beta
        self.log = log

        # History of every loss term, appended to during mask learning
        self.all_loss = defaultdict(list)

    def _init_masks(self, ghetero):
        """Initialize the learnable edge mask.

        Parameters
        ----------
        ghetero : DGLGraph
            Input graph.

        Returns
        -------
        edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges
        """
        return get_edge_mask_dict(ghetero)

    def _prune_graph(self, ghetero, prune_max_degree=-1, k_core=2, always_preserve=None):
        """Prune edges by (optionally) removing edges of high degree nodes and extracting the k-core.

        The pruning is computed on the homogeneous graph, i.e., ignoring node/edge types.

        Parameters
        ----------
        ghetero : heterogeneous dgl graph
        prune_max_degree : int
            If positive, first remove the edges of nodes with a higher degree.
        k_core : int
            k for the k-core extraction
        always_preserve : iterable, optional
            Homogeneous node ids that are never pruned.

        Returns
        -------
        pruned_ghetero : heterogeneous dgl graph
            `ghetero` with the pruned edges removed
        etypes_to_pruned_ghetero_eid_masks : dict
            key=`etype`, value=bool tensor over the `etype` edges of `ghetero`,
            True for the edges that are kept
        """
        # The pruning helpers index a tensor with this value, so never hand them None.
        always_preserve = [] if always_preserve is None else always_preserve

        ghomo = dgl.to_homogeneous(ghetero)
        device = ghetero.device
        # Remember the edge ids before pruning; remove_edges relabels the remaining edges.
        ghomo.edata['eid_before_prune'] = torch.arange(ghomo.num_edges()).to(device)

        if prune_max_degree > 0:
            base_ghomo = remove_edges_of_high_degree_nodes(ghomo, prune_max_degree, always_preserve)
        else:
            base_ghomo = ghomo

        k_core_ghomo = remove_edges_except_k_core_graph(base_ghomo, k_core, always_preserve)
        if k_core_ghomo.num_edges() <= 0:  # no k-core found
            pruned_ghomo = base_ghomo
        else:
            pruned_ghomo = k_core_ghomo

        pruned_ghomo_eids = pruned_ghomo.edata['eid_before_prune']
        # Create the mask on the graph device: it is indexed with device tensors
        # right below, which fails for a CPU mask when the graph lives on a GPU.
        pruned_ghomo_eid_mask = torch.zeros(ghomo.num_edges(), dtype=torch.bool, device=device)
        pruned_ghomo_eid_mask[pruned_ghomo_eids] = True

        # Apply the pruning result on the heterogeneous graph.
        # This relies on the homogeneous edges being laid out as consecutive
        # blocks, one per canonical etype, in `canonical_etypes` order.
        etypes_to_pruned_ghetero_eid_masks = {}
        pruned_ghetero = ghetero
        cum_num_edges = 0
        for etype in ghetero.canonical_etypes:
            num_edges = ghetero.num_edges(etype=etype)
            pruned_ghetero_eid_mask = pruned_ghomo_eid_mask[cum_num_edges:cum_num_edges+num_edges]
            etypes_to_pruned_ghetero_eid_masks[etype] = pruned_ghetero_eid_mask

            remove_ghetero_eids = (~ pruned_ghetero_eid_mask).nonzero().view(-1).to(device)
            pruned_ghetero = dgl.remove_edges(pruned_ghetero, eids=remove_ghetero_eids, etype=etype)

            cum_num_edges += num_edges

        return pruned_ghetero, etypes_to_pruned_ghetero_eid_masks

    def path_loss(self, src_nid, tgt_nid, g, eweights, num_paths=5):
        """Compute the path loss.

        Parameters
        ----------
        src_nid : int
            source node id (in `g`)
        tgt_nid : int
            target node id (in `g`)
        g : dgl graph
            Homogeneous graph carrying the 'eweight' edge data
        eweights : Tensor
            Edge weights with shape equals the number of edges.
        num_paths : int
            Number of paths to compute path loss on

        Returns
        -------
        loss : Tensor
            The path loss
        """
        neg_path_score_func = get_neg_path_score_func(g, 'eweight', [src_nid, tgt_nid])
        paths = k_shortest_paths_with_max_length(g,
                                                 src_nid,
                                                 tgt_nid,
                                                 weight=neg_path_score_func,
                                                 k=num_paths)

        eids_on_path = get_eids_on_paths(paths, g).to(eweights.device)

        # Reward large weights on the selected paths ...
        if eids_on_path.nelement() > 0:
            loss_on_path = - eweights[eids_on_path].mean()
        else:
            loss_on_path = 0

        # ... and penalise weights everywhere else.
        all_eids = torch.arange(eweights.shape[0], device=eweights.device)
        eids_off_path_mask = ~torch.isin(all_eids, eids_on_path)
        if eids_off_path_mask.any():
            loss_off_path = eweights[eids_off_path_mask].mean()
        else:
            loss_off_path = 0

        loss = self.alpha * loss_on_path + self.beta * loss_off_path

        self.all_loss['loss_on_path'] += [float(loss_on_path)]
        self.all_loss['loss_off_path'] += [float(loss_off_path)]

        return loss

    def get_edge_mask(self,
                      src_nid,
                      tgt_nid,
                      ghetero,
                      feat_nids,
                      prune_max_degree=-1,
                      k_core=2,
                      prune_graph=True,
                      with_path_loss=True):
        """Learning the edge mask dict.

        Parameters
        ----------
        see the `explain` method.

        Returns
        -------
        edge_mask_dict : dict
            key=`etype`, value=detached tensor with size being the number of `etype`
            edges of `ghetero`. Edges removed by pruning carry the value -inf.
        """
        self.model.eval()
        device = ghetero.device

        ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)
        homo_src_nid = ntype_hetero_nids_to_homo_nids[(self.src_ntype, int(src_nid))]
        homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(self.tgt_ntype, int(tgt_nid))]

        # Get the initial prediction.
        with torch.no_grad():
            score = self.model(src_nid, tgt_nid, ghetero, feat_nids)
            pred = (score > 0).int().item()

        if prune_graph:
            # The pruned graph for mask learning
            ml_ghetero, etypes_to_pruned_ghetero_eid_masks = self._prune_graph(ghetero,
                                                                               prune_max_degree,
                                                                               k_core,
                                                                               [homo_src_nid, homo_tgt_nid])
        else:
            # The original graph for mask learning
            ml_ghetero = ghetero

        ml_edge_mask_dict = self._init_masks(ml_ghetero)
        optimizer = torch.optim.Adam(ml_edge_mask_dict.values(), lr=self.lr)

        if self.log:
            pbar = tqdm(total=self.num_epochs)

        eweight_norm = 0
        EPS = 1e-3
        for e in range(self.num_epochs):
            # Apply sigmoid to edge_mask to get eweight
            ml_eweight_dict = {etype: ml_edge_mask_dict[etype].sigmoid() for etype in ml_edge_mask_dict}

            score = self.model(src_nid, tgt_nid, ml_ghetero, feat_nids, ml_eweight_dict)
            # The sign follows the initial prediction: a predicted link (pred=1)
            # minimises -log(sigmoid(score)), otherwise +log(sigmoid(score)).
            pred_loss = (-1) ** pred * score.sigmoid().log()
            self.all_loss['pred_loss'] += [pred_loss.item()]

            ml_ghetero.edata['eweight'] = ml_eweight_dict
            ml_ghomo = dgl.to_homogeneous(ml_ghetero, edata=['eweight'])
            ml_ghomo_eweights = ml_ghomo.edata['eweight']

            # Check for early stop
            curr_eweight_norm = ml_ghomo_eweights.norm()
            if abs(eweight_norm - curr_eweight_norm) < EPS:
                break
            eweight_norm = curr_eweight_norm

            # Update with path loss
            if with_path_loss:
                path_loss = self.path_loss(homo_src_nid, homo_tgt_nid, ml_ghomo, ml_ghomo_eweights)
            else:
                path_loss = 0

            loss = pred_loss + path_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            self.all_loss['total_loss'] += [loss.item()]

            if self.log:
                pbar.update(1)

        if self.log:
            pbar.close()

        # Placeholder masks sized for the unpruned graph
        edge_mask_dict_placeholder = self._init_masks(ghetero)
        edge_mask_dict = {}

        if prune_graph:
            # remove pruned edges
            for etype in ghetero.canonical_etypes:
                # Start from -inf everywhere, then copy the learned values
                # into the positions of the edges that survived pruning.
                edge_mask = edge_mask_dict_placeholder[etype].data + float('-inf')
                pruned_ghetero_eid_mask = etypes_to_pruned_ghetero_eid_masks[etype]
                edge_mask[pruned_ghetero_eid_mask] = ml_edge_mask_dict[etype]
                edge_mask_dict[etype] = edge_mask
        else:
            edge_mask_dict = ml_edge_mask_dict

        edge_mask_dict = {k: v.detach() for k, v in edge_mask_dict.items()}
        return edge_mask_dict

    def get_paths(self,
                  src_nid,
                  tgt_nid,
                  ghetero,
                  edge_mask_dict,
                  num_paths=1,
                  max_path_length=3):
        """A postprocessing step that turns the `edge_mask_dict` into actual paths.

        Parameters
        ----------
        edge_mask_dict : dict
            key=`etype`, value=torch.nn.Parameter with size being the number of `etype` edges
        Others: see the `explain` method.

        Returns
        -------
        paths: list of lists
            each list contains (cannonical edge type, source node ids, target node ids)
        """
        ntype_pairs_to_cannonical_etypes = get_ntype_pairs_to_cannonical_etypes(ghetero)
        eweight_dict = {etype: edge_mask_dict[etype].sigmoid() for etype in edge_mask_dict}
        ghetero.edata['eweight'] = eweight_dict

        # convert ghetero to ghomo and find paths
        ghomo = dgl.to_homogeneous(ghetero, edata=['eweight'])
        ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)
        homo_src_nid = ntype_hetero_nids_to_homo_nids[(self.src_ntype, int(src_nid))]
        homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(self.tgt_ntype, int(tgt_nid))]

        # The score function works on `ghomo`, so it must be given homogeneous node
        # ids, as `path_loss` does. The type-local ids `src_nid`/`tgt_nid`
        # generally refer to other nodes of the homogeneous graph.
        neg_path_score_func = get_neg_path_score_func(ghomo, 'eweight', [homo_src_nid, homo_tgt_nid])
        homo_paths = k_shortest_paths_with_max_length(ghomo,
                                                      homo_src_nid,
                                                      homo_tgt_nid,
                                                      weight=neg_path_score_func,
                                                      k=num_paths,
                                                      max_length=max_path_length)

        paths = []
        homo_nids_to_ntype_hetero_nids = get_homo_nids_to_ntype_hetero_nids(ghetero)

        if len(homo_paths) > 0:
            for homo_path in homo_paths:
                hetero_path = []
                for i in range(1, len(homo_path)):
                    homo_u, homo_v = homo_path[i-1], homo_path[i]
                    hetero_u_ntype, hetero_u_nid = homo_nids_to_ntype_hetero_nids[homo_u]
                    hetero_v_ntype, hetero_v_nid = homo_nids_to_ntype_hetero_nids[homo_v]
                    can_etype = ntype_pairs_to_cannonical_etypes[(hetero_u_ntype, hetero_v_ntype)]
                    hetero_path += [(can_etype, hetero_u_nid, hetero_v_nid)]
                paths += [hetero_path]
        else:
            # A rare case, no paths found, take the top edges
            cat_edge_mask = torch.cat([v for v in edge_mask_dict.values()])
            M = len(cat_edge_mask)
            k = min(num_paths * max_path_length, M)
            # Smallest mask value among the k largest ones
            threshold = cat_edge_mask.topk(k)[0][-1].item()
            path = []
            for etype in edge_mask_dict:
                u, v = ghetero.edges(etype=etype)
                topk_edge_mask = edge_mask_dict[etype] >= threshold
                path += list(zip([etype] * topk_edge_mask.sum().item(), u[topk_edge_mask].tolist(), v[topk_edge_mask].tolist()))
            paths = [path]
        return paths

    def explain(self,
                src_nid,
                tgt_nid,
                ghetero,
                num_hops=2,
                prune_max_degree=-1,
                k_core=2,
                num_paths=1,
                max_path_length=3,
                prune_graph=True,
                with_path_loss=True,
                return_mask=False):
        """Return a path explanation of a predicted link

        Parameters
        ----------
        src_nid : int
            source node id
        tgt_nid : int
            target node id
        ghetero : dgl graph
        num_hops : int
            Number of hops to extract the computation graph, i.e. GNN # layers
        prune_max_degree : int
            If positive, prune the edges of graph nodes with degree larger than `prune_max_degree`
            If -1, do nothing
        k_core : int
            k for the the k-core graph extraction
        num_paths : int
            Number of paths for the postprocessing path extraction
        max_path_length : int
            Maximum length of paths for the postprocessing path extraction
        prune_graph : bool
            If true apply the max_degree and/or k-core pruning. For ablation. Default True.
        with_path_loss : bool
            If true include the path loss. For ablation. Default True.
        return_mask : bool
            If true return the edge mask in addition to the path. For AUC evaluation. Default False

        Returns
        -------
        paths: list of lists
            each list contains (cannonical edge type, source node ids, target node ids)
        (optional) edge_mask_dict : dict
            key=`etype`, value=tensor with size being the number of `etype` edges
            of the computation graph
        """
        # Extract the computation graph (k-hop subgraph)
        (comp_g_src_nid,
         comp_g_tgt_nid,
         comp_g,
         comp_g_feat_nids) = hetero_src_tgt_khop_in_subgraph(self.src_ntype,
                                                             src_nid,
                                                             self.tgt_ntype,
                                                             tgt_nid,
                                                             ghetero,
                                                             num_hops)

        # Learn the edge mask on the computation graph
        comp_g_edge_mask_dict = self.get_edge_mask(comp_g_src_nid,
                                                   comp_g_tgt_nid,
                                                   comp_g,
                                                   comp_g_feat_nids,
                                                   prune_max_degree,
                                                   k_core,
                                                   prune_graph,
                                                   with_path_loss)

        # Extract paths
        comp_g_paths = self.get_paths(comp_g_src_nid,
                                      comp_g_tgt_nid,
                                      comp_g,
                                      comp_g_edge_mask_dict,
                                      num_paths,
                                      max_path_length)

        # Covert the node id in computation graph to original graph
        paths = comp_g_paths_to_paths(comp_g, comp_g_paths)

        if return_mask:
            # return masks for easier evaluation
            return paths, comp_g_edge_mask_dict
        else:
            return paths
model.py
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import HeteroEmbedding, EdgePredictor
'''
HeteroRGCN model adapted from the DGL official tutorial
https://docs.dgl.ai/en/0.6.x/tutorials/basics/5_hetero.html
https://docs.dgl.ai/en/0.8.x/tutorials/models/1_gnn/4_rgcn.html
'''
class HeteroRGCNLayer(nn.Module):
    """One relational graph convolution layer with a self connection.

    Holds one shared linear map `weight0` for a node's own feature and one
    linear map per relation name in `etypes`.
    """
    def __init__(self, in_size, out_size, etypes):
        super().__init__()
        # W_0 transforms the node's own feature
        self.weight0 = nn.Linear(in_size, out_size)
        # W_r, one per relation
        relation_linears = {}
        for name in etypes:
            relation_linears[name] = nn.Linear(in_size, out_size)
        self.weight = nn.ModuleDict(relation_linears)

    def forward(self, g, feat_dict, eweight_dict=None):
        """Run one round of message passing on `g`.

        `feat_dict` maps node type to input features. When `eweight_dict` is
        given, it is stored as edge data and every message is multiplied by it.
        Returns a dict mapping each node type of `g` to its updated feature.
        """
        weighted = eweight_dict is not None
        if weighted:
            # Make the edge weights available to the message functions
            g.edata['_edge_weight'] = eweight_dict

        funcs = {}
        for can_etype in g.canonical_etypes:
            srctype, etype, _ = can_etype
            src_feat = feat_dict[srctype]

            # h_0 = W_0 * h, the self connection term
            g.nodes[srctype].data['h0'] = self.weight0(src_feat)

            # h_r = W_r * h, stored on the graph for message passing
            message_key = 'Wh_%s' % etype
            g.nodes[srctype].data[message_key] = self.weight[etype](src_feat)

            # Per-relation (message_func, reduce_func). All relations write to the
            # same destination feature 'h', which the type wise reducer combines.
            if weighted:
                message_func = fn.u_mul_e(message_key, '_edge_weight', 'm')
            else:
                message_func = fn.copy_u(message_key, 'm')
            funcs[can_etype] = (message_func, fn.mean('m', 'h'))

        def add_self_connection(nodes):
            return {'h': nodes.data['h'] + nodes.data['h0']}

        # Relations are averaged individually, summed across relation types,
        # and the self connection is added afterwards.
        g.multi_update_all(funcs, 'sum', add_self_connection)

        # the updated node feature dictionary
        updated = {}
        for ntype in g.ntypes:
            updated[ntype] = g.nodes[ntype].data['h']
        return updated
class HeteroRGCN(nn.Module):
    """Two-layer heterogeneous RGCN on top of learnable node embeddings."""
    def __init__(self, g, emb_dim, hidden_size, out_size):
        super().__init__()
        # One embedding table per node type, sized by that type's node count
        nodes_per_type = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        self.emb = HeteroEmbedding(nodes_per_type, emb_dim)
        self.layer1 = HeteroRGCNLayer(emb_dim, hidden_size, g.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, g.etypes)

    def forward(self, g, feat_nids=None, eweight_dict=None):
        """Encode the nodes of `g`.

        With `feat_nids=None` the full embedding tables are the input features,
        otherwise only the embeddings selected by `feat_nids`.
        Returns the node feature dict produced by the second layer.
        """
        if feat_nids is not None:
            feat_dict = self.emb(feat_nids)
        else:
            feat_dict = self.emb.weight

        hidden = self.layer1(g, feat_dict, eweight_dict)
        activated = {}
        for ntype, h in hidden.items():
            activated[ntype] = F.leaky_relu(h)
        return self.layer2(g, activated, eweight_dict)
class HeteroLinkPredictionModel(nn.Module):
    """Link predictor: a node encoder followed by a `dgl.nn.EdgePredictor`."""
    def __init__(self, encoder, src_ntype, tgt_ntype, link_pred_op='dot', **kwargs):
        super().__init__()
        self.src_ntype = src_ntype
        self.tgt_ntype = tgt_ntype
        self.encoder = encoder
        # Extra keyword arguments go to the predictor unchanged
        self.predictor = EdgePredictor(op=link_pred_op, **kwargs)

    def encode(self, g, feat_nids=None, eweight_dict=None):
        """Return the encoder's node representation dict for `g`."""
        return self.encoder(g, feat_nids, eweight_dict)

    def forward(self, src_nids, tgt_nids, g, feat_nids=None, eweight_dict=None):
        """Score the links between `src_nids` and `tgt_nids`.

        Returns the predictor output flattened to one dimension.
        """
        node_repr = self.encode(g, feat_nids, eweight_dict)
        src_h = node_repr[self.src_ntype][src_nids]
        tgt_h = node_repr[self.tgt_ntype][tgt_nids]
        return self.predictor(src_h, tgt_h).view(-1)
pagelink.py
import os
import torch
import argparse
import pickle
from tqdm.auto import tqdm
from pathlib import Path
from utils import set_seed, print_args, set_config_args
from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel
from explainer import PaGELink
# Command-line interface of the explanation script.
parser = argparse.ArgumentParser(description='Explain link predictor')
parser.add_argument('--device_id', type=int, default=-1)

# Dataset args
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1)
parser.add_argument('--test_ratio', type=float, default=0.2)
parser.add_argument(
    '--max_num_samples', type=int, default=-1,
    help='maximum number of samples to explain, for fast testing. Use all if -1',
)

# GNN args
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)
parser.add_argument('--saved_model_dir', type=str, default='saved_models')
parser.add_argument('--saved_model_name', type=str, default='')

# Link predictor args
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument(
    '--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
    help='operation passed to dgl.EdgePredictor',
)

# Explanation args
parser.add_argument('--lr', type=float, default=0.01, help='explainer learning_rate')
parser.add_argument('--alpha', type=float, default=1.0, help='explainer on-path edge regularizer weight')
parser.add_argument('--beta', type=float, default=1.0, help='explainer off-path edge regularizer weight')
parser.add_argument('--num_hops', type=int, default=2, help='computation graph number of hops')
parser.add_argument('--num_epochs', type=int, default=20, help='How many epochs to learn the mask')
parser.add_argument('--num_paths', type=int, default=40, help='How many paths to generate')
parser.add_argument('--max_path_length', type=int, default=5, help='max lenght of generated paths')
parser.add_argument('--k_core', type=int, default=2, help='k for the k-core graph')
parser.add_argument(
    '--prune_max_degree', type=int, default=200,
    help='prune the graph such that all nodes have degree smaller than max_degree. No prune if -1',
)
parser.add_argument(
    '--save_explanation', default=False, action='store_true',
    help='Whether to save the explanation',
)
parser.add_argument(
    '--saved_explanation_dir', type=str, default='saved_explanations',
    help='directory of saved explanations',
)
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()

# Values found in the optional config file replace the parsed ones.
if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'pagelink')

# The dataset family fixes the node types at the two ends of the predicted link.
if 'citation' in args.dataset_name:
    args.src_ntype, args.tgt_ntype = 'author', 'paper'
elif 'synthetic' in args.dataset_name:
    args.src_ntype, args.tgt_ntype = 'user', 'item'

# A negative device id (the default) selects the CPU.
device = (
    torch.device('cuda', index=args.device_id)
    if torch.cuda.is_available() and args.device_id >= 0
    else torch.device('cpu')
)

# Only the 'cat' predictor receives extra constructor arguments.
pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1} if args.link_pred_op in ['cat'] else {}

if not args.saved_model_name:
    args.saved_model_name = f'{args.dataset_name}_model'

print_args(args)
set_seed(0)
# Load the processed graphs: the message-passing graph plus the positive and
# negative train / validation / test graphs, all moved to `device`.
processed_g = load_dataset(args.dataset_dir, args.dataset_name, args.valid_ratio, args.test_ratio)[1]
(mp_g,
 train_pos_g, train_neg_g,
 val_pos_g, val_neg_g,
 test_pos_g, test_neg_g) = (graph.to(device) for graph in processed_g)

# Rebuild the link predictor and restore its trained weights.
encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)
state = torch.load(f'{args.saved_model_dir}/{args.saved_model_name}.pth', map_location='cpu')
model.load_state_dict(state)

# NOTE(review): the checkpoint is loaded on CPU; the model is presumably moved
# to `device` together with the explainer by the `.to(device)` call below -- confirm.
pagelink = PaGELink(
    model,
    lr=args.lr,
    alpha=args.alpha,
    beta=args.beta,
    num_epochs=args.num_epochs,
    log=True,
).to(device)

test_src_nids, test_tgt_nids = test_pos_g.edges()
test_ids = range(test_src_nids.shape[0])
if args.max_num_samples > 0:
    test_ids = test_ids[:args.max_num_samples]

# Explanations are keyed by ((src_ntype, src_nid), (tgt_ntype, tgt_nid)).
pred_edge_to_comp_g_edge_mask = {}
pred_edge_to_paths = {}
for i in tqdm(test_ids):
    src_nid = test_src_nids[i].unsqueeze(0)
    tgt_nid = test_tgt_nids[i].unsqueeze(0)

    with torch.no_grad():
        pred = model(src_nid, tgt_nid, mp_g).sigmoid().item() > 0.5

    # Only links that the model predicts as positive are explained.
    if not pred:
        continue

    src_tgt = ((args.src_ntype, int(src_nid)), (args.tgt_ntype, int(tgt_nid)))
    paths, comp_g_edge_mask_dict = pagelink.explain(
        src_nid,
        tgt_nid,
        mp_g,
        args.num_hops,
        args.prune_max_degree,
        args.k_core,
        args.num_paths,
        args.max_path_length,
        return_mask=True,
    )
    pred_edge_to_comp_g_edge_mask[src_tgt] = comp_g_edge_mask_dict
    pred_edge_to_paths[src_tgt] = paths
# Optionally persist the learned edge masks and the path explanations.
if args.save_explanation:
    # `exist_ok=True` replaces the check-then-create sequence, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(args.saved_explanation_dir, exist_ok=True)

    saved_edge_explanation_file = f'pagelink_{args.saved_model_name}_pred_edge_to_comp_g_edge_mask'
    saved_path_explanation_file = f'pagelink_{args.saved_model_name}_pred_edge_to_paths'

    # Move every mask tensor to CPU before pickling.
    pred_edge_to_comp_g_edge_mask = {
        edge: {k: v.cpu() for k, v in mask.items()}
        for edge, mask in pred_edge_to_comp_g_edge_mask.items()
    }

    saved_edge_explanation_path = Path.cwd().joinpath(args.saved_explanation_dir, saved_edge_explanation_file)
    with open(saved_edge_explanation_path, "wb") as f:
        pickle.dump(pred_edge_to_comp_g_edge_mask, f)

    saved_path_explanation_path = Path.cwd().joinpath(args.saved_explanation_dir, saved_path_explanation_file)
    with open(saved_path_explanation_path, "wb") as f:
        pickle.dump(pred_edge_to_paths, f)
train_linkpred.py
import os
import torch
import torch.nn.functional as F
import copy
import argparse
from sklearn.metrics import roc_auc_score
from pathlib import Path
from utils import set_seed, negative_sampling, print_args, set_config_args
from data_processing import load_dataset
from model import HeteroRGCN, HeteroLinkPredictionModel
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Train a GNN-based link prediction model')
parser.add_argument('--device_id', type=int, default=-1)

# Dataset args
parser.add_argument('--dataset_dir', type=str, default='datasets')
parser.add_argument('--dataset_name', type=str, default='aug_citation')
parser.add_argument('--valid_ratio', type=float, default=0.1)
parser.add_argument('--test_ratio', type=float, default=0.2)

# GNN args
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--out_dim', type=int, default=128)

# Link predictor args
parser.add_argument('--src_ntype', type=str, default='user', help='prediction source node type')
parser.add_argument('--tgt_ntype', type=str, default='item', help='prediction target node type')
parser.add_argument('--pred_etype', type=str, default='likes', help='prediction edge type')
parser.add_argument('--link_pred_op', type=str, default='dot', choices=['dot', 'cos', 'ele', 'cat'],
                    help='operation passed to dgl.EdgePredictor')
parser.add_argument('--lr', type=float, default=0.01, help='link predictor learning_rate')
parser.add_argument('--num_epochs', type=int, default=200, help='How many epochs to train')
parser.add_argument('--eval_interval', type=int, default=1, help="Evaluate once per how many epochs")
parser.add_argument('--save_model', default=False, action='store_true', help='Whether to save the model')
parser.add_argument('--saved_model_dir', type=str, default='saved_models', help='Where to save the model')
parser.add_argument('--sample_neg_edges', default=False, action='store_true',
                    help='If False, use fixed negative edges. If True, sample negative edges in each epoch')
parser.add_argument('--config_path', type=str, default='', help='path of saved configuration args')

args = parser.parse_args()

# The dataset family fixes the node types at the two ends of the predicted link.
if 'synthetic' in args.dataset_name:
    args.src_ntype = 'user'
    args.tgt_ntype = 'item'
elif 'citation' in args.dataset_name:
    args.src_ntype = 'author'
    args.tgt_ntype = 'paper'

# A negative device id (the default) selects the CPU.
if torch.cuda.is_available() and args.device_id >= 0:
    device = torch.device('cuda', index=args.device_id)
else:
    device = torch.device('cpu')

# Values found in the optional config file replace the parsed ones.
if args.config_path:
    args = set_config_args(args, args.config_path, args.dataset_name, 'train_eval')

# Derive the predictor kwargs only AFTER the config overrides are applied, so
# that a `link_pred_op` / `out_dim` coming from the config file is honoured.
if args.link_pred_op in ['cat']:
    pred_kwargs = {"in_feats": args.out_dim, "out_feats": 1}
else:
    pred_kwargs = {}

print_args(args)
def compute_loss(pos_score, neg_score):
    '''
    Binary cross-entropy (on logits) of positive scores against label 1 and
    negative scores against label 0.

    Parameters
    ----------
    pos_score : tensor
        Logits of positive edges.
    neg_score : tensor
        Logits of negative edges.

    Returns
    ----------
    Scalar loss tensor.
    '''
    logits = torch.cat([pos_score, neg_score])
    num_pos, num_neg = pos_score.shape[0], neg_score.shape[0]
    targets = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
    # Labels are built on CPU and then moved next to the scores.
    targets = targets.to(logits.device)
    return F.binary_cross_entropy_with_logits(logits, targets)
def compute_auc(pos_score, neg_score):
    '''
    ROC-AUC of positive scores (label 1) against negative scores (label 0).

    Parameters
    ----------
    pos_score : tensor
    neg_score : tensor

    Returns
    ----------
    The value returned by `roc_auc_score`.
    '''
    num_pos, num_neg = pos_score.shape[0], neg_score.shape[0]
    all_scores = torch.cat([pos_score, neg_score])
    # Detach and move to CPU so the scores can be handed over as a numpy array.
    y_score = all_scores.detach().cpu().numpy()
    y_true = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)]).numpy()
    return roc_auc_score(y_true, y_score)
def run():
    '''
    Train the module-level `model` for link prediction and report test AUC.

    Reads the module-level `args`, `model`, `optimizer`, `mp_g` and the
    positive / negative train, validation and test graphs. Every
    `args.eval_interval` epochs the validation AUC is computed and the weights
    with the best validation AUC are remembered. After training, these weights
    are loaded back into `model` and the test AUC is printed.
    '''
    set_seed(0)
    pred_etype = args.pred_etype

    train_pos_src_nids, train_pos_tgt_nids = train_pos_g.edges(etype=pred_etype)
    val_pos_src_nids, val_pos_tgt_nids = val_pos_g.edges(etype=pred_etype)
    val_neg_src_nids, val_neg_tgt_nids = val_neg_g.edges(etype=pred_etype)
    test_pos_src_nids, test_pos_tgt_nids = test_pos_g.edges(etype=pred_etype)
    test_neg_src_nids, test_neg_tgt_nids = test_neg_g.edges(etype=pred_etype)
    # Fixed negatives; replaced every epoch when `args.sample_neg_edges` is set.
    train_neg_src_nids, train_neg_tgt_nids = train_neg_g.edges(etype=pred_etype)

    # Start from the initial weights so `state` and `best_epoch` are always
    # defined, even if no evaluation runs or none improves on the initial
    # best AUC (e.g. eval_interval > num_epochs).
    best_val_auc = 0
    best_epoch = 0
    state = copy.deepcopy(model.state_dict())

    for epoch in range(1, args.num_epochs + 1):
        train_pos_score = model(train_pos_src_nids, train_pos_tgt_nids, mp_g)
        if args.sample_neg_edges:
            train_neg_src_nids, train_neg_tgt_nids = negative_sampling(train_pos_g, pred_etype)
        train_neg_score = model(train_neg_src_nids, train_neg_tgt_nids, mp_g)

        loss = compute_loss(train_pos_score, train_neg_score)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                train_auc = compute_auc(train_pos_score, train_neg_score)
                val_pos_score = model(val_pos_src_nids, val_pos_tgt_nids, mp_g)
                val_neg_score = model(val_neg_src_nids, val_neg_tgt_nids, mp_g)
                val_auc = compute_auc(val_pos_score, val_neg_score)
            print('In epoch {}, loss: {:.4f}, train AUC: {:.4f}, val AUC: {:.4f}'.format(
                epoch, loss.item(), train_auc, val_auc))

            if val_auc > best_val_auc:
                best_epoch = epoch
                best_val_auc = val_auc
                # Deep copy: `state_dict()` holds references to the live parameters.
                state = copy.deepcopy(model.state_dict())

    with torch.no_grad():
        model.eval()
        model.load_state_dict(state)
        test_pos_score = model(test_pos_src_nids, test_pos_tgt_nids, mp_g)
        test_neg_score = model(test_neg_src_nids, test_neg_tgt_nids, mp_g)
        test_auc = compute_auc(test_pos_score, test_neg_score)
    print('Best epoch {}, val AUC: {:.4f}, test AUC: {:.4f}'.format(best_epoch, best_val_auc, test_auc))
# Load the processed graphs: the message-passing graph plus the positive and
# negative train / validation / test graphs, all moved to `device`.
processed_g = load_dataset(args.dataset_dir, args.dataset_name, args.valid_ratio, args.test_ratio)[1]
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = [g.to(device) for g in processed_g]

encoder = HeteroRGCN(mp_g, args.emb_dim, args.hidden_dim, args.out_dim)
model = HeteroLinkPredictionModel(encoder, args.src_ntype, args.tgt_ntype, args.link_pred_op, **pred_kwargs)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# `run` leaves the weights of the best validation epoch loaded in `model`.
run()

if args.save_model:
    output_dir = Path.cwd().joinpath(args.saved_model_dir)
    # `exist_ok=True` replaces the check-then-create sequence, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), output_dir.joinpath(f"{args.dataset_name}_model.pth"))
utils.py
import dgl
import torch
import random
import textwrap
import yaml
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
from dgl.subgraph import khop_in_subgraph
from itertools import count
from heapq import heappop, heappush
from sklearn.metrics import roc_auc_score
def set_seed(seed):
    '''
    Seed the Python, NumPy and PyTorch random number generators with `seed`
    and switch cuDNN to deterministic, non-benchmarking mode.

    Parameters
    ----------
    seed : int
    '''
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def print_args(args):
    '''
    Print one line per attribute of `args`: the attribute name padded to 25
    characters, a space, then its value.

    Parameters
    ----------
    args : object with a `__dict__`, e.g. an argparse namespace
    '''
    arg_table = vars(args)
    for name in arg_table:
        value = arg_table[name]
        print(f'{name:25} {value}')
def set_config_args(args, config_path, dataset_name, model_name=''):
    '''
    Overwrite attributes of `args` with values read from a YAML config file.

    Parameters
    ----------
    args : object
        Receives one attribute per config key (set with `setattr`).
    config_path : string
        Path of the YAML file.
    dataset_name : string
        Top-level key selecting the section of the file to use.
    model_name : string
        If non-empty, a second-level key selecting a sub-section.

    Returns
    ----------
    args : the same object, after modification
    '''
    # NOTE(review): FullLoader can build more than plain data types; only load
    # configuration files you trust.
    with open(config_path, "r") as conf:
        all_configs = yaml.load(conf, Loader=yaml.FullLoader)

    overrides = all_configs[dataset_name]
    if model_name:
        overrides = overrides[model_name]

    for key in overrides:
        setattr(args, key, overrides[key])
    return args
'''
Model training utils
'''
def idx_split(idx, ratio, seed=0):
    """
    Randomly split `idx` into two parts whose sizes are in proportion
    `ratio` : 1 - `ratio`. The first part holds int(len(idx) * ratio) entries.

    Calls `set_seed(seed)` first, so the global random state is reset.

    Parameters
    ----------
    idx : tensor
    ratio : float
    seed : int

    Returns
    ----------
    idx1, idx2 : the two index tensors after the split
    """
    set_seed(seed)
    total = len(idx)
    boundary = int(total * ratio)
    shuffled_positions = torch.randperm(total)

    idx1 = idx[shuffled_positions[:boundary]]
    idx2 = idx[shuffled_positions[boundary:]]

    # Sanity check: the two parts together contain exactly the input values.
    recombined = torch.cat([idx1, idx2]).sort()[0]
    assert((recombined == idx.sort()[0]).all())
    return idx1, idx2
def eids_split(eids, val_ratio, test_ratio, seed=0):
    """
    Split `eids` into three parts: train, valid, and test,
    where train : valid : test = (1 - `val_ratio` - `test_ratio`) : `val_ratio` : `test_ratio`

    Parameters
    ----------
    eids : tensor
        edge ids
    val_ratio : float
    test_ratio : float
    seed : int

    Returns
    ----------
    Three edge ids (tensor) after split
    """
    train_ratio = (1 - val_ratio - test_ratio)
    train_eids, pred_eids = idx_split(eids, train_ratio, seed)

    # Share of the held-out edges that goes to validation. When both ratios are
    # zero nothing is held out, so avoid dividing by zero: both parts are empty.
    held_out_ratio = 1 - train_ratio
    val_share = val_ratio / held_out_ratio if held_out_ratio else 0.0

    val_eids, test_eids = idx_split(pred_eids, val_share, seed)
    return train_eids, val_eids, test_eids
def negative_sampling(graph, pred_etype=None, num_neg_samples=None):
    '''
    Sample node pairs that are not connected by an edge of type `pred_etype`.

    Adapted from PyG negative_sampling function
    https://pytorch-geometric.readthedocs.io/en/1.7.2/_modules/torch_geometric/utils/
    negative_sampling.html#negative_sampling

    Candidate pairs are drawn once, without replacement, and observed edges are
    filtered out, so fewer than the requested number of negatives may be
    returned.

    Parameters
    ----------
    graph : dgl graph
    pred_etype : string
        The edge type for prediction
    num_neg_samples : int
        Number of negatives wanted. Defaults to the number of observed edges.

    Returns
    ----------
    Two negative nids. Nids for src and tgt nodes of the `pred_etype`.
    Both are int64 tensors on the device of the observed edges, also when no
    negative can be sampled.
    '''
    # src_N: total number of src nodes
    # N (tgt_N): total number of tgt nodes
    # M: total number of possible edges, src_N * tgt_N
    # pos_M: number of positive samples (observed edges)
    # neg_M: number of negative samples
    pos_src_nids, pos_tgt_nids = graph.edges(etype=pred_etype)
    if pred_etype is None:
        N = graph.num_nodes()
        M = N * N
    else:
        src_ntype, _, tgt_ntype = graph.to_canonical_etype(pred_etype)
        src_N, N = graph.num_nodes(src_ntype), graph.num_nodes(tgt_ntype)
        M = src_N * N

    pos_M = pos_src_nids.shape[0]
    neg_M = num_neg_samples or pos_M
    neg_M = min(neg_M, M - pos_M)  # in case M - pos_M < neg_M

    # Nothing to sample (no observed edges, no free pair, or no nodes). Return
    # empty index tensors explicitly: building them from an empty Python list
    # would give float tensors, which cannot be used for indexing, and the
    # ratio below would divide by zero when M == 0.
    if neg_M <= 0:
        empty = torch.empty(0, dtype=torch.long, device=pos_src_nids.device)
        return empty, empty.clone()

    # Oversampling factor, so that enough candidates survive the removal of
    # observed edges in most cases.
    denominator = 1 - 1.1 * (pos_M / M)
    if denominator == 0:
        size = M
    else:
        alpha = abs(1 / denominator)
        size = min(M, int(alpha * neg_M))

    # Each candidate pair (src, tgt) is encoded as the integer src * N + tgt.
    perm = torch.tensor(random.sample(range(M), size), dtype=torch.long)
    idx = pos_src_nids * N + pos_tgt_nids

    # Drop candidates that coincide with observed edges.
    mask = torch.isin(perm, idx.to('cpu')).to(torch.bool)
    perm = perm[~mask][:neg_M].to(pos_src_nids.device)

    neg_src_nids = torch.div(perm, N, rounding_mode='floor')
    neg_tgt_nids = perm % N
    return neg_src_nids, neg_tgt_nids
'''
DGL graph manipulation utils
'''
def get_homo_nids_to_hetero_nids(ghetero):
    '''
    Map every node id of the homogeneous version of `ghetero` to the id that
    `dgl.to_homogeneous` records for it under `dgl.NID`.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    Returns
    ----------
    homo_nids_to_hetero_nids : dict
        key=homogeneous node id, value=recorded original node id
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    # Position i of the NID data belongs to homogeneous node i.
    recorded_nids = ghomo.ndata[dgl.NID].tolist()
    return dict(enumerate(recorded_nids))
def get_homo_nids_to_ntype_hetero_nids(ghetero):
    '''
    Map every node id of the homogeneous version of `ghetero` to a tuple
    (node type, node id) of the heterogeneous graph.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    Returns
    ----------
    homo_nids_to_ntype_hetero_nids : dict
        key=homogeneous node id, value=(node type, node id)
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    type_names = ghetero.ntypes
    type_ids = ghomo.ndata[dgl.NTYPE]
    hetero_nids = ghomo.ndata[dgl.NID].tolist()

    homo_nids_to_ntype_hetero_nids = {}
    for homo_nid, (type_id, hetero_nid) in enumerate(zip(type_ids, hetero_nids)):
        # Relies on the type ids following the order of `ghetero.ntypes`.
        homo_nids_to_ntype_hetero_nids[homo_nid] = (type_names[type_id], hetero_nid)
    return homo_nids_to_ntype_hetero_nids
def get_ntype_hetero_nids_to_homo_nids(ghetero):
    '''
    Map every tuple (node type, node id) of `ghetero` to the node id in the
    homogeneous version of the graph. This is the inverse of the mapping built
    by `get_homo_nids_to_ntype_hetero_nids`.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    Returns
    ----------
    ntype_hetero_nids_to_homo_nids : dict
        key=(node type, node id), value=homogeneous node id
    '''
    forward_map = get_homo_nids_to_ntype_hetero_nids(ghetero)
    ntype_hetero_nids_to_homo_nids = {}
    for homo_nid, ntype_and_nid in forward_map.items():
        ntype_hetero_nids_to_homo_nids[ntype_and_nid] = homo_nid
    return ntype_hetero_nids_to_homo_nids
def get_ntype_pairs_to_cannonical_etypes(ghetero, pred_etype='likes'):
    '''
    Map tuples (source node type, target node type) to canonical edge types,
    leaving out edges whose type is `pred_etype`.
    A helper function for path finding.

    If several edge types connect the same pair of node types, only the last
    one listed in `ghetero.canonical_etypes` is kept.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph
    pred_etype : string
        The edge type for prediction

    Returns
    ----------
    ntype_pairs_to_cannonical_etypes : dict
        key=(source node type, target node type), value=canonical edge type
    '''
    return {
        (can_etype[0], can_etype[2]): tuple(can_etype)
        for can_etype in ghetero.canonical_etypes
        if can_etype[1] != pred_etype
    }
def get_num_nodes_dict(ghetero):
    '''
    Count the nodes of every node type of a heterogeneous graph.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph

    Returns
    ----------
    num_nodes_dict : dict
        key=node type, value=number of nodes
    '''
    return {ntype: ghetero.num_nodes(ntype) for ntype in ghetero.ntypes}
def remove_all_edges_of_etype(ghetero, etype):
    '''
    Remove all edges with type `etype` from `ghetero`. If `etype` is not in `ghetero`, do nothing.

    Parameters
    ----------
    ghetero : heterogeneous dgl graph
    etype : string or triple of strings
        Edge type in simple form (string) or cannonical form (triple of strings)

    Returns
    ----------
    removed_ghetero : heterogeneous dgl graph
        A graph without the edges of `etype`, or `ghetero` itself when the
        edge type is not part of the graph.
    '''
    # Check a simple (string) edge type before converting it, because DGL's
    # `to_canonical_etype` is documented to raise for an unknown edge type --
    # which would defeat the "do nothing" contract above.
    if isinstance(etype, str) and etype not in ghetero.etypes:
        return ghetero

    etype = ghetero.to_canonical_etype(etype)
    if etype not in ghetero.canonical_etypes:
        return ghetero

    eids = ghetero.edges('eid', etype=etype)
    return dgl.remove_edges(ghetero, eids, etype=etype)
def hetero_src_tgt_khop_in_subgraph(src_ntype, src_nid, tgt_ntype, tgt_nid, ghetero, k):
    '''
    Extract the `k`-hop in-subgraph of `ghetero` seeded with both the source
    node and the target node.
    See the dgl `khop_in_subgraph` function as a reference
    https://docs.dgl.ai/en/0.9.x/generated/dgl.khop_in_subgraph.html

    Parameters
    ----------
    src_ntype : string
        source node type
    src_nid : int or tensor
        source node id; a tensor is converted with `.item()`
    tgt_ntype : string
        target node type
    tgt_nid : int or tensor
        target node id; a tensor is converted with `.item()`
    ghetero : heterogeneous dgl graph
    k : int
        Number of hops

    Return
    ----------
    sghetero_src_nid :
        id of the source node in the subgraph
    sghetero_tgt_nid :
        id of the target node in the subgraph
    sghetero : heterogeneous dgl graph
        The subgraph returned by `khop_in_subgraph`
    sghetero_feat_nid :
        `sghetero.ndata[dgl.NID]`, i.e. the original `ghetero` node ids of the
        subgraph nodes, for feature identification
    '''
    if torch.is_tensor(src_nid):
        src_nid = src_nid.item()
    if torch.is_tensor(tgt_nid):
        tgt_nid = tgt_nid.item()

    same_ntype = src_ntype == tgt_ntype
    if same_ntype:
        # Both seeds share one entry; the source comes first, the target second.
        seeds = {src_ntype: torch.tensor([src_nid, tgt_nid])}
    else:
        seeds = {src_ntype: src_nid, tgt_ntype: tgt_nid}

    sghetero, inv_dict = khop_in_subgraph(ghetero, seeds, k)

    if same_ntype:
        sghetero_src_nid = inv_dict[src_ntype][0]
        sghetero_tgt_nid = inv_dict[tgt_ntype][1]
    else:
        sghetero_src_nid = inv_dict[src_ntype]
        sghetero_tgt_nid = inv_dict[tgt_ntype]

    return sghetero_src_nid, sghetero_tgt_nid, sghetero, sghetero.ndata[dgl.NID]
'''
Path finding utils
'''
def get_neg_path_score_func(g, weight, exclude_node=None):
    '''
    Build the edge cost function used by the shortest path algorithm.

    The cost of an edge (u, v) is log(in-degree of v) - log(edge weight), so
    edges with a large weight that lead to low-degree nodes are cheap.

    Parameters
    ----------
    g : dgl graph
    weight : string
        Key of the edge weights stored in g.edata
    exclude_node : iterable, optional
        The log in-degree of these nodes is set to 0 when computing the cost,
        so they are more likely to be included. Defaults to no node.

    Returns
    ----------
    neg_path_score_func : callable function
        Takes in two node ids and returns the cost of the edge between them.
        Raises KeyError for a pair that is not an edge of `g`.
    '''
    log_eweights = g.edata[weight].log().tolist()
    log_in_degrees = g.in_degrees().log()
    # `None` replaces the former mutable `[]` default; nothing is excluded then.
    if exclude_node is not None:
        log_in_degrees[exclude_node] = 0
    log_in_degrees = log_in_degrees.tolist()

    u, v = g.edges()
    neg_path_score_map = {
        (src, tgt): log_in_degrees[tgt] - log_eweights[eid]
        for eid, (src, tgt) in enumerate(zip(u.tolist(), v.tolist()))
    }

    def neg_path_score_func(u, v):
        return neg_path_score_map[(u, v)]

    return neg_path_score_func
def bidirectional_dijkstra(g, src_nid, tgt_nid, weight=None, ignore_nodes=None, ignore_edges=None):
    """Dijkstra's algorithm for shortest paths using bidirectional search.

    Adapted from NetworkX _bidirectional_dijkstra
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html

    Two searches run in alternation, one forward from `src_nid` and one
    backward from `tgt_nid`, until a node has been finalized by both.

    Parameters
    ----------
    g : dgl graph
    src_nid : int
        source node id
    tgt_nid : int
        target node id
    weight : callable function, optional
        Takes in two node ids and returns the edge weight. Every edge counts
        as 1 when omitted.
    ignore_nodes : container of nodes
        nodes to ignore, optional
    ignore_edges : container of edges
        edges to ignore, optional, given as (source, target) tuples

    Returns
    -------
    (length, path) : tuple
        Shortest path length and the path as a list of node ids.

    Raises
    ------
    ValueError
        If no path exists, or if a contradiction caused by negative weights
        is detected.
    """
    if src_nid == tgt_nid:
        return (0, [src_nid])

    edge_src, edge_tgt = g.edges()

    def predecessors(node):
        return edge_src[edge_tgt == node].tolist()

    def successors(node):
        return edge_tgt[edge_src == node].tolist()

    if ignore_nodes:
        def without_ignored_nodes(neighbors):
            def filtered(node):
                return (nbr for nbr in neighbors(node) if nbr not in ignore_nodes)
            return filtered

        predecessors = without_ignored_nodes(predecessors)
        successors = without_ignored_nodes(successors)

    if ignore_edges:
        def without_ignored_edges(neighbors, as_edge):
            def filtered(node):
                return (nbr for nbr in neighbors(node) if as_edge(node, nbr) not in ignore_edges)
            return filtered

        # An edge always reads (source, target), whatever the search direction.
        predecessors = without_ignored_edges(predecessors, lambda node, nbr: (nbr, node))
        successors = without_ignored_edges(successors, lambda node, nbr: (node, nbr))

    if not weight:
        weight = lambda u, v: 1

    # Index 0 holds the state of the forward search, index 1 of the backward one.
    settled = [{}, {}]  # final distances
    routes = [{src_nid: [src_nid]}, {tgt_nid: [tgt_nid]}]  # best known path to each seen node
    frontier = [[], []]  # heaps of (distance, tie breaker, node)
    tentative = [{src_nid: 0}, {tgt_nid: 0}]  # best known distance to each seen node
    expand = [successors, predecessors]

    tie_breaker = count()
    heappush(frontier[0], (0, next(tie_breaker), src_nid))
    heappush(frontier[1], (0, next(tie_breaker), tgt_nid))

    best_dist = None  # only read once `best_path` is non-empty
    best_path = []
    side = 1
    while frontier[0] and frontier[1]:
        # Alternate between the two searches; the first step goes forward.
        side = 1 - side
        dist, _, node = heappop(frontier[side])
        if node in settled[side]:
            # Stale heap entry: the node was already finalized.
            continue
        settled[side][node] = dist
        if node in settled[1 - side]:
            # Finalized by both searches: the best path found so far is optimal.
            return (best_dist, best_path)

        for nbr in expand[side](node):
            # The backward search walks edges against their direction.
            if side == 0:
                step = weight(node, nbr)
            else:
                step = weight(nbr, node)
            reach = settled[side][node] + step

            if nbr in settled[side]:
                if reach < settled[side][nbr]:
                    raise ValueError("Contradictory paths found: negative weights?")
                continue
            if nbr in tentative[side] and not reach < tentative[side][nbr]:
                continue

            # Relax the edge.
            tentative[side][nbr] = reach
            heappush(frontier[side], (reach, next(tie_breaker), nbr))
            routes[side][nbr] = routes[side][node] + [nbr]

            if nbr in tentative[0] and nbr in tentative[1]:
                # Both searches reach `nbr`: a complete src-tgt path candidate.
                through = tentative[0][nbr] + tentative[1][nbr]
                if best_path == [] or best_dist > through:
                    best_dist = through
                    backward_leg = routes[1][nbr][::-1]
                    best_path = routes[0][nbr] + backward_leg[1:]

    raise ValueError("No paths found")
class PathBuffer:
    """Priority queue of candidate paths that refuses duplicates.

    For shortest paths finding
    Adapted from NetworkX shortest_simple_paths
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html
    """

    def __init__(self):
        self.paths = set()  # paths currently stored, as tuples, for fast lookup
        self.sortedpaths = []  # heap of (cost, insertion number, path)
        self.counter = count()  # insertion numbers break ties between equal costs

    def __len__(self):
        return len(self.sortedpaths)

    def push(self, cost, path):
        """Store `path` with priority `cost` unless it is already stored."""
        key = tuple(path)
        if key in self.paths:
            return
        heappush(self.sortedpaths, (cost, next(self.counter), path))
        self.paths.add(key)

    def pop(self):
        """Remove and return the stored path with the lowest cost."""
        _, _, path = heappop(self.sortedpaths)
        self.paths.remove(tuple(path))
        return path
def k_shortest_paths_generator(g,
                               src_nid,
                               tgt_nid,
                               weight=None,
                               k=5,
                               ignore_nodes_init=None,
                               ignore_edges_init=None):
    """Generate at most `k` simple paths in the graph g from src_nid to tgt_nid,
    starting from the shortest ones.

    If a weighted shortest path search is to be used, no negative weights are allowed.

    Adapted from NetworkX shortest_simple_paths
    https://networkx.org/documentation/stable/_modules/networkx/algorithms/simple_paths.html

    Parameters
    ----------
    g : dgl graph
    src_nid : int
        source node id
    tgt_nid : int
        target node id
    weight : callable function, optional
        Takes in two node ids and returns the edge weight.
    k : int
        number of paths
    ignore_nodes_init : set of nodes
        nodes to ignore, optional
    ignore_edges_init : set of edges
        edges to ignore, optional

    Returns
    -------
    path_generator : generator
        A generator that produces paths, in order from shortest to longest.
        Each path is a list of node ids.

    Raises
    ------
    ValueError
        Raised while iterating if the first shortest path search finds no path.
    """
    if not weight:
        weight = lambda u, v: 1

    def path_cost(nodes):
        return sum(weight(a, b) for (a, b) in zip(nodes, nodes[1:]))

    accepted = []  # paths already produced
    candidates = PathBuffer()  # deviations waiting to be produced
    last_path = None
    while not last_path or len(accepted) < k:
        if last_path:
            banned_nodes = set(ignore_nodes_init) if ignore_nodes_init else set()
            banned_edges = set(ignore_edges_init) if ignore_edges_init else set()
            # Branch off the previous path at each of its nodes in turn.
            for split in range(1, len(last_path)):
                root = last_path[:split]
                root_cost = path_cost(root)
                # Forbid the continuation taken by any produced path that
                # shares this root, so that the spur is a new path.
                for earlier in accepted:
                    if earlier[:split] == root:
                        banned_edges.add((earlier[split - 1], earlier[split]))
                try:
                    spur_cost, spur = bidirectional_dijkstra(g,
                                                             root[-1],
                                                             tgt_nid,
                                                             ignore_nodes=banned_nodes,
                                                             ignore_edges=banned_edges,
                                                             weight=weight)
                    candidates.push(root_cost + spur_cost, root[:-1] + spur)
                except ValueError:
                    # No spur path from this root; try the next one.
                    pass
                # Later, longer roots must not revisit this node.
                banned_nodes.add(root[-1])
        else:
            # First round: the plain shortest path. A failure here propagates.
            first_cost, first_path = bidirectional_dijkstra(g,
                                                            src_nid,
                                                            tgt_nid,
                                                            weight,
                                                            ignore_nodes_init,
                                                            ignore_edges_init)
            candidates.push(first_cost, first_path)

        if not candidates:
            break
        best = candidates.pop()
        yield best
        accepted.append(best)
        last_path = best
def k_shortest_paths_with_max_length(g,
                                     src_nid,
                                     tgt_nid,
                                     weight=None,
                                     k=5,
                                     max_length=None,
                                     ignore_nodes=None,
                                     ignore_edges=None):
    """Collect at most `k` simple paths in the graph g from src_nid to tgt_nid,
    starting from the shortest ones, and keep those with at most `max_length` edges.

    The length filter is applied after the `k` paths are generated, so fewer
    than `k` paths may be returned.
    If a weighted shortest path search is to be used, no negative weights are allowed.

    Parameters
    ----------
    See function `k_shortest_paths_generator`
    max_length : int, optional
        Maximum number of edges of a returned path. No limit if falsy.

    Return
    -------
    paths : list of lists
        Each list is a path containing node ids. If the search raises
        ValueError, the result is `[[]]`.
    """
    path_generator = k_shortest_paths_generator(g,
                                                src_nid,
                                                tgt_nid,
                                                weight=weight,
                                                k=k,
                                                ignore_nodes_init=ignore_nodes,
                                                ignore_edges_init=ignore_edges)
    try:
        found = list(path_generator)
    except ValueError:
        return [[]]

    if not max_length:
        return found
    # A path with `max_length` edges has `max_length + 1` nodes.
    return [path for path in found if len(path) <= max_length + 1]
'''
Evaluation utils
'''
def get_comp_g_edge_labels(comp_g, edge_labels):
    """Turn `edge_labels` with node ids in the original graph to
    `comp_g_edge_labels` with node ids in the computation graph.
    For easier evaluation.

    Parameters
    ----------
    comp_g : heterogeneous dgl graph
        computation graph, with .ndata stores key dgl.NID
    edge_labels : dict
        key=edge type, value=(source node ids, target node ids)

    Return
    -------
    comp_g_edge_labels : dict
        key=edge type, value=a tensor of labels, each label is in {0, 1}
    """
    lookup_by_ntype = {}
    max_nid_by_ntype = {}
    for ntype, nids in comp_g.ndata[dgl.NID].items():
        max_nid = nids.max().item() if nids.numel() > 0 else -1
        max_nid_by_ntype[ntype] = max_nid

        # Entry i is the comp_g id of original node i; -1 marks nodes that are
        # not part of comp_g.
        lookup = torch.zeros(max_nid + 1).long() - 1
        lookup[nids] = torch.arange(nids.shape[0])
        lookup_by_ntype[ntype] = lookup

    comp_g_edge_labels = {}
    for can_etype, (start_nids, end_nids) in edge_labels.items():
        start_ntype, etype, end_ntype = can_etype

        # Drop labelled edges whose end points lie beyond the largest original
        # id present in comp_g.
        # NOTE(review): ids below the maximum that are absent from comp_g map
        # to -1 and are not filtered here -- confirm callers never pass them.
        end_in_range = end_nids <= max_nid_by_ntype[end_ntype]
        start_in_range = start_nids <= max_nid_by_ntype[start_ntype]
        keep = end_in_range & start_in_range

        comp_g_start_nids = lookup_by_ntype[start_ntype][start_nids[keep]]
        comp_g_end_nids = lookup_by_ntype[end_ntype][end_nids[keep]]
        comp_g_eids = comp_g.edge_ids(comp_g_start_nids.tolist(), comp_g_end_nids.tolist(), etype=etype)

        labels = torch.zeros(comp_g.num_edges(etype=can_etype))
        labels[comp_g_eids] = 1
        comp_g_edge_labels[can_etype] = labels
    return comp_g_edge_labels
def get_comp_g_path_labels(comp_g, path_labels):
    """Turn `path_labels` with node ids in the original graph to
    `comp_g_path_labels` with edge ids in the computation graph.
    For easier evaluation.

    Parameters
    ----------
    comp_g : heterogeneous dgl graph
        computation graph, with .ndata stores key dgl.NID
    path_labels : list of lists
        Each list is a path, i.e., triples of
        (cannonical edge type, source node id, target node id)

    Returns
    -------
    comp_g_path_labels : list of lists
        Each list is a path, i.e., tuples of (cannonical edge type, edge id)
    """
    lookup_by_ntype = {}
    for ntype, nids in comp_g.ndata[dgl.NID].items():
        max_nid = nids.max().item() if nids.numel() > 0 else -1
        # Entry i is the comp_g id of original node i; -1 marks nodes that are
        # not part of comp_g.
        lookup = torch.zeros(max_nid + 1).long() - 1
        lookup[nids] = torch.arange(nids.shape[0])
        lookup_by_ntype[ntype] = lookup

    comp_g_path_labels = []
    for path in path_labels:
        comp_g_path = []
        for can_etype, start_nid, end_nid in path:
            start_ntype, etype, end_ntype = can_etype
            comp_g_start_nid = lookup_by_ntype[start_ntype][start_nid].item()
            comp_g_end_nid = lookup_by_ntype[end_ntype][end_nid].item()
            comp_g_eid = comp_g.edge_ids(comp_g_start_nid, comp_g_end_nid, etype=can_etype)
            comp_g_path.append((can_etype, comp_g_eid))
        comp_g_path_labels.append(comp_g_path)
    return comp_g_path_labels
def eval_edge_mask_auc(edge_mask_dict, edge_labels):
    '''
    Evaluate the AUC of an edge mask.

    The masks of all edge types found in `edge_labels` are passed through a
    sigmoid, concatenated and scored against the concatenated labels.

    Parameters
    ----------
    edge_mask_dict : dict
        key=edge type, value=a tensor of mask values, each in (-inf, inf)
    edge_labels : dict
        key=edge type, value=a tensor of labels, each label is in {0, 1}

    Returns
    ----------
    ROC-AUC score, as returned by `roc_auc_score`
    '''
    can_etypes = list(edge_labels)
    y_true = torch.cat([edge_labels[can_etype] for can_etype in can_etypes])
    y_score = torch.cat([edge_mask_dict[can_etype].detach().sigmoid() for can_etype in can_etypes])
    return roc_auc_score(y_true, y_score)
def eval_edge_mask_topk_path_hit(edge_mask_dict, path_labels, topks=[10]):
    '''
    Evaluate the path hit of the top k edges in an edge mask

    For every requested `k`, the k-th largest value over all edge types is used
    as a threshold and every edge whose mask value is >= that threshold is
    selected (so edges tied with the k-th value are selected as well). The
    selected edges are then checked against the labelled paths.

    Parameters
    ----------
    edge_mask_dict: dict
        key=edge type, value=a tensor of mask values, each value is in (-inf, inf)
    path_labels: list of lists
        Each list is a path, i.e., tuples of (canonical edge type, edge id)
    topks: iterable
        An iterable of the top `k` values. Each `k` determines how many edges to select
        from the top values of the mask. A `k` larger than the total number of edges is
        clipped to that number. The default list is only read, never mutated.

    Returns
    ----------
    topk_to_path_hit: defaultdict(list)
        Mapping each requested `k` to a one-element list holding the path hit
        (1 or 0) returned by `eval_hard_edge_mask_path_hit`.
    '''
    masks = list(edge_mask_dict.values())
    # torch.cat raises on an empty list, so an empty dict is treated as an empty mask
    cat_edge_mask = torch.cat(masks) if masks else torch.empty(0)
    M = len(cat_edge_mask)

    # requested k -> effective k, clipped to the number of available edges
    clipped_topks = {k: min(k, M) for k in topks}

    topk_to_path_hit = defaultdict(list)
    for requested_k, k in clipped_topks.items():
        if k == 0:
            # No edge can be selected (empty mask or k == 0). topk(0) returns no
            # values, so there is no k-th value to threshold on: select nothing.
            hard_edge_mask_dict = {
                etype: torch.zeros_like(mask, dtype=torch.bool)
                for etype, mask in edge_mask_dict.items()
            }
        else:
            # The smallest of the top-k values is the selection threshold
            threshold = cat_edge_mask.topk(k)[0][-1].item()
            hard_edge_mask_dict = {
                etype: mask >= threshold for etype, mask in edge_mask_dict.items()
            }

        hit = eval_hard_edge_mask_path_hit(hard_edge_mask_dict, path_labels)
        topk_to_path_hit[requested_k] += [hit]

    return topk_to_path_hit
def eval_hard_edge_mask_path_hit(hard_edge_mask_dict, path_labels):
    '''
    Evaluate the path hit of a hard edge mask

    A path is hit when every one of its edges is selected by the mask. The
    result is 1 as soon as one labelled path is hit, otherwise 0.

    Parameters
    ----------
    hard_edge_mask_dict: dict
        key=edge type, value=a tensor of labels, each label is in {True, False}
    path_labels: list of lists
        Each list is a path, i.e., tuples of (canonical edge type, edge id)

    Returns
    ----------
    hit_path: int
        1 or 0
    '''
    def _is_fully_selected(path):
        # all() stops at the first unselected edge; an empty path counts as selected
        return all(hard_edge_mask_dict[can_etype][eid] for can_etype, eid in path)

    # any() stops at the first labelled path that is fully selected
    return 1 if any(_is_fully_selected(path) for path in path_labels) else 0
def eval_path_explanation_edges_path_hit(path_explanation_edges, path_labels):
    '''
    Evaluate the path hit of a set of path explanation edges

    A labelled path is hit when each of its edges is contained in
    `path_explanation_edges`. The result is 1 as soon as one labelled path is
    hit, otherwise 0.

    Parameters
    ----------
    path_explanation_edges : list
        Edges on the path explanation, each edge is a triple of
        (canonical edge type, source node id, target node id)
    path_labels : list of lists
        Each list is a path, i.e., triples of
        (canonical edge type, source node id, target node id)

    Returns
    ----------
    hit_path: int
        1 or 0
    '''
    for labelled_path in path_labels:
        # Membership is tested against the given container as-is
        if all(edge in path_explanation_edges for edge in labelled_path):
            return 1
    return 0
'''
Plotting utils
'''
def plot_hetero_graph(ghetero,
                      ntypes_to_nshapes=None,
                      ntypes_to_ncolors=None,
                      ntypes_to_nlayers=None,
                      layout='multipartite',
                      layout_seed=0,
                      node_size=1000,
                      edge_kwargs={},
                      selected_node_dict=None,
                      selected_node_color='red',
                      selected_edge_dict=None,
                      selected_edge_kwargs={},
                      label='nid',
                      etype_label=True,
                      label_offset=False,
                      title=None,
                      legend=False,
                      figsize=(10, 10),
                      fig_name=None,
                      fig_format='png',
                      is_show=True):
    '''
    Draw a DGL heterogeneous graph with networkx/matplotlib.

    The graph is converted to a homogeneous graph, laid out, and drawn one node
    type at a time so that each type can have its own shape and color.

    Parameters
    ----------
    ghetero: a DGL heterogeneous graph
    ntypes_to_nshapes : Dict
        mapping node types to node shapes. If None or empty, every node uses 'o'.
    ntypes_to_ncolors : Dict
        mapping node types to node colors. If None or empty, every node uses 'cyan'.
    ntypes_to_nlayers : Dict
        mapping node types to layer order in the multipartite layout.
        If None, the position of the node type in `ghetero.ntypes` is used.
    layout : String
        one of ['spring', 'kk', 'multipartite']
    layout_seed : int
        seed passed to the spring layout
    node_size : int
        node size passed to `nx.draw_networkx_nodes`
    edge_kwargs : Dict
        extra keyword arguments for `nx.draw_networkx_edges` when drawing all edges.
        The default dict is only read, never mutated.
    selected_node_dict : Dict
        mapping node types to an iterable of node ids (ids in `ghetero`) to highlight
    selected_node_color : String
        color used for the highlighted nodes
    selected_edge_dict : Dict
        mapping edge types to a pair (source node ids, target node ids) of tensors
        holding node ids in `ghetero`; these edges are drawn a second time
    selected_edge_kwargs : Dict
        extra keyword arguments for `nx.draw_networkx_edges` when drawing the selected
        edges. The default dict is only read, never mutated.
    label: String
        one of ['none', 'nid'] or a node feature stored in ndata of ghetero
    etype_label : bool
        whether to draw edge type labels. Edge labels are only drawn when `label`
        is a node feature stored in ndata of ghetero.
    label_offset : bool
        if True, node feature labels are shifted down by 0.8 / figsize[1]
    title : String
        figure title, wrapped at 60 characters
    legend : bool
        whether to draw a legend listing `ghetero.ntypes`
    figsize : tuple
        figure size passed to `plt.figure`
    fig_name : String
        if not None, the figure is saved under this name and closed at the end
    fig_format : String
        format passed to `plt.savefig`
    is_show : bool
        whether to call `plt.show()`

    Returns
    ----------
    nx_graph : networkx graph

    Raises
    ----------
    ValueError
        if `layout` is unknown, or `label` is neither 'none', 'nid', nor a node
        feature stored in ndata of ghetero
    '''
    # Always bind the fallbacks: the per-type lookup below falls back to them for
    # an empty mapping as well as for None.
    default_node_shape = 'o'
    default_node_color = 'cyan'

    if selected_node_dict is not None:
        selected_node_dict = {ntype: list(selected_node_dict[ntype]) for ntype in selected_node_dict}

    # Convert DGL graph to networkx graph
    ghomo = dgl.to_homogeneous(ghetero)
    edges = torch.cat([t.unsqueeze(1) for t in ghomo.edges()], dim=1)
    edge_list = [(n_frm, n_to) for (n_frm, n_to) in edges.tolist()]
    nx_graph = dgl.to_networkx(ghomo, node_attrs=[dgl.NTYPE])

    # Use different layout
    if layout == 'spring':
        pos = nx.spring_layout(nx_graph, seed=layout_seed)
    elif layout == 'kk':
        pos = nx.kamada_kawai_layout(nx_graph)
    elif layout == 'multipartite':
        if ntypes_to_nlayers is not None:
            ntype_ids_to_nlayers = {ghetero.get_ntype_id(ntype): ntypes_to_nlayers[ntype] for ntype in ghetero.ntypes}
        else:
            ntype_ids_to_nlayers = {ghetero.get_ntype_id(ntype): i for i, ntype in enumerate(ghetero.ntypes)}

        # Overwrite the node type id stored on each node with its layer index
        for i in nx_graph.nodes():
            ntype_id = nx_graph.nodes()[i][dgl.NTYPE].item()
            nx_graph.nodes()[i][dgl.NTYPE] = ntype_ids_to_nlayers[ntype_id]

        # Read the layer from the same node attribute it was written to above
        pos = nx.multipartite_layout(nx_graph, subset_key=dgl.NTYPE, scale=1)
    else:
        raise ValueError('Unknown layout')

    # Start drawing
    plt.figure(figsize=figsize)
    ax = plt.gca()

    # These do not depend on the node type, so look them up once
    ntype_ids = ghomo.ndata[dgl.NTYPE]
    hetero_nids = ghomo.ndata[dgl.NID]  # nid in the original hetero graph

    # Draw nodes for each ntype
    for ntype in ghetero.ntypes:
        node_shape = ntypes_to_nshapes[ntype] if ntypes_to_nshapes else default_node_shape
        node_color = ntypes_to_ncolors[ntype] if ntypes_to_ncolors else default_node_color

        # For the current node type, get the node type id and node ids
        curr_ntype_id = ghetero.get_ntype_id(ntype)
        curr_nids_mask = ntype_ids == curr_ntype_id
        curr_nids = curr_nids_mask.nonzero().view(-1).tolist()

        # For the current node type, get node ids and selected node ids in the original hetero graph
        curr_hetero_nids = hetero_nids[curr_nids_mask]
        if selected_node_dict is not None:
            curr_hetero_selected_nid = selected_node_dict.get(ntype)
            if curr_hetero_selected_nid is not None:
                # One color per node: highlighted nodes get the selected color
                curr_node_color = []
                for hetero_nid in curr_hetero_nids:
                    curr_node_color += [selected_node_color if hetero_nid in curr_hetero_selected_nid else node_color]
                node_color = curr_node_color

        nx.draw_networkx_nodes(nx_graph,
                               pos,
                               curr_nids,
                               node_shape=node_shape,
                               node_color=node_color,
                               node_size=node_size,
                               ax=ax)

    # Draw edges
    nx.draw_networkx_edges(nx_graph, pos, edge_list, **edge_kwargs, ax=ax)

    if selected_edge_dict is not None:
        # Translate (ntype, hetero nid) endpoints to node ids of the homogeneous graph
        ntype_hetero_nids_to_homo_nids = get_ntype_hetero_nids_to_homo_nids(ghetero)
        homo_selected_edge_list = []
        for etype in selected_edge_dict:
            src_ntype, _, tgt_ntype = ghetero.to_canonical_etype(etype)
            src_nids, tgt_nids = selected_edge_dict[etype]
            for src_nid, tgt_nid in zip(src_nids.tolist(), tgt_nids.tolist()):
                homo_src_nid = ntype_hetero_nids_to_homo_nids[(src_ntype, src_nid)]
                homo_tgt_nid = ntype_hetero_nids_to_homo_nids[(tgt_ntype, tgt_nid)]
                homo_selected_edge_list += [(homo_src_nid, homo_tgt_nid)]

        nx.draw_networkx_edges(nx_graph, pos, homo_selected_edge_list, **selected_edge_kwargs, ax=ax)

    # Start labelling nodes
    if label == 'none':
        pass
    elif label == 'nid':
        homo_nids_to_hetero_nids = get_homo_nids_to_hetero_nids(ghetero)
        nx.draw_networkx_labels(nx_graph, pos, labels=homo_nids_to_hetero_nids)
    else:
        # Set extra space to avoid label outside of the box
        x_values, y_values = zip(*pos.values())
        x_max = max(x_values)
        x_min = min(x_values)
        x_margin = (x_max - x_min) * 0.12
        ax.set_xlim(x_min - x_margin, x_max + x_margin)

        if ghetero.ndata.get(label):
            homo_nids_to_hetero_ndata_feat = get_homo_nids_to_hetero_ntype_data_feat(ghetero, label)
            if label_offset:
                offset = 0.8 / figsize[1]
                label_pos = {nid: [p[0], p[1] - offset] for nid, p in pos.items()}
            else:
                label_pos = pos
            nx.draw_networkx_labels(nx_graph,
                                    label_pos,
                                    font_size=14,
                                    font_weight='bold',
                                    labels=homo_nids_to_hetero_ndata_feat,
                                    horizontalalignment='center',
                                    verticalalignment='center',
                                    ax=ax)
        else:
            raise ValueError('Unrecognized label')

    # Start labelling edges with etype.
    # `etype_label` is a boolean flag, so test its truth value: with an
    # `is not None` check, passing False would still enable the labels.
    if etype_label:
        # NOTE(review): edge labels are only drawn when `label` is a node feature
        # in ndata, so they are skipped for label='nid' / 'none' - confirm intended.
        if ghetero.ndata.get(label):
            homo_nid_pairs_to_etypes = get_homo_nid_pairs_to_etypes(ghetero)
            nx.draw_networkx_edge_labels(nx_graph,
                                         pos,
                                         font_size=13,
                                         font_weight='bold',
                                         edge_labels=homo_nid_pairs_to_etypes,
                                         horizontalalignment='center',
                                         verticalalignment='center',
                                         ax=ax)

    if legend:
        plt.legend(ghetero.ntypes, fontsize=15, prop={'size': figsize[0]*2.5}, bbox_to_anchor = (1.15, 0.7))

    ax.axis('off')

    if title is not None:
        plt.title(textwrap.fill(title, width=60))

    if fig_name is not None:
        plt.savefig(fig_name, format=fig_format, bbox_inches='tight')

    if is_show:
        plt.show()

    # Close the figure once it has been written to disk
    if fig_name is not None:
        plt.close()

    return nx_graph
def get_homo_nids_to_hetero_ntype_data_feat(ghetero, feat=dgl.NID):
    '''
    Plotting helper function

    Build one text label per node: the first character of the node type
    followed by the node's value of the `feat` node feature. Labels are
    collected node type by node type in `ghetero.ntypes` order and paired with
    the consecutive ids 0..N-1 of the homogeneous version of the graph.
    (Assumes `dgl.to_homogeneous` numbers nodes in that same order - TODO confirm.)

    Parameters
    ----------
    ghetero : a DGL heterogeneous graph
    feat : String
        key of the node feature in `ghetero.ndata` used for the labels

    Returns
    ----------
    homo_nids_to_hetero_ndata_feat : dict
        key=node id in the homogeneous graph, value=label string
    '''
    num_homo_nodes = dgl.to_homogeneous(ghetero).num_nodes()

    node_labels = []
    for ntype in ghetero.ntypes:
        type_prefix = ntype[0]
        for value in ghetero.ndata[feat][ntype].tolist():
            node_labels.append(f'{type_prefix}{value}')

    return dict(zip(range(num_homo_nodes), node_labels))
def get_homo_nid_pairs_to_etypes(ghetero):
    '''
    Plotting helper function

    Map every (source node id, target node id) pair of the homogeneous version
    of the graph to the name of its edge type. When several edges connect the
    same ordered pair, the edge type of the last one wins.

    Parameters
    ----------
    ghetero : a DGL heterogeneous graph

    Returns
    ----------
    homo_nid_pairs_to_etypes : dict
        key=(source node id, target node id) in the homogeneous graph, value=edge type name
    '''
    ghomo = dgl.to_homogeneous(ghetero)
    etype_names = ghetero.etypes
    src_nids, tgt_nids = ghomo.edges()

    pairs_to_etypes = {}
    for src_nid, tgt_nid, etype_id in zip(src_nids.tolist(), tgt_nids.tolist(), ghomo.edata[dgl.ETYPE]):
        pairs_to_etypes[(src_nid, tgt_nid)] = etype_names[etype_id]
    return pairs_to_etypes
代码执行流程:首先训练链路预测(边预测)模型,然后再针对该模型的预测结果生成解释。
论文实验中的示例预测的是不同类型节点之间的链接;如果要预测相同类型节点之间的链接,还需要对代码进行微调。其次,文中图里的边都是双向边,如果要加入单向边,数据处理部分也需要相应地微调。祝大家好好学习,天天向上。