代码以及视频讲解
本文所涉及所有资源均在传知代码平台可获取
1. 论文概述
基于关系有向图(r-digraph)的知识图推理方法,旨在解决传统基于关系路径推理方法的局限性。r-digraph由重叠的关系路径组成,用于捕获知识图谱中的局部证据。相比于单一路径,r-digraph更为复杂,因此需要有效的构建和学习方法。
为了应对这一挑战,作者提出了一种图神经网络的变体,称为RED-GNN。RED-GNN利用动态规划对具有共享边的多个r-digraph进行递归编码,以有效地捕获图中的信息。此外,为了提高对查询相关信息的捕获,RED-GNN采用了查询依赖的注意机制,以选择与查询相关的关系边。
研究结果表明,RED-GNN方法不仅在效率上具有优势,而且在归纳和转换推理任务中表现出显著的性能提升,相比于现有方法。此外,RED-GNN学习到的注意权值还可以为知识图推理提供可解释的证据,从而提高推理结果的可解释性。
总的来说,该方法为解决知识图推理中的复杂性和可解释性问题提供了一种有效的解决方案,有望在知识图谱领域取得重要进展。
论文:Knowledge Graph Reasoning with Relational Digraph
代码:https://github.com/LARS-research/RED-GNN
在此基础上本文添加tensorboard可视化结果
2. 论文方法
文章提出了两种方法:RED-Simp和递归r-digraph编码。
在RED-Simp中,通过提取子图结构并运行消息传递来编码r-有向图。然而,这种方法的计算成本很高,因为它需要对每个可能的答案实体进行独立的计算。
为了提高效率,提出了递归r-digraph编码方法,利用了r-有向图中共享的信息。这种方法通过动态规划逐层地构建r-有向图,以便多个查询可以共享相同的计算。递归编码的关键优势是减少了计算的重复性,从而提高了效率。
除此之外,论文使用注意力机制来捕获查询相关的知识,并将其编码到r-有向图中。通过设计合适的消息传递函数和评分函数,他们能够对查询进行推理,并从中提取可解释的局部证据。这种方法提高了模型的表现,并使得推理结果更容易理解。
3. 实验部分
3.1 实验条件
所有实验都是用PyTorch框架用Python编写的,并在具有8GB内存的RTX 3070Ti GPU上运行。
3.2 数据集
WN18RR:数据集描述:WN18RR是一个基于WordNet的知识图谱,其中包含从WordNet中提取的实体和关系。
FB15k237:数据集描述:FB15k237是一个基于Freebase知识图谱的数据集,其中包含了大量真实世界中的实体和关系。
NELL-995:数据集描述:NELL-995是一个包含995个关系的知识图谱数据集,关系代表了从新闻网站中提取的知识。
3.3 实验步骤
step1:安装环境依赖
Requirements
- torch==1.12.1
- numpy==1.21.6
- torch_scatter 2.0.9
step2:进入项目路径
cd transductive
step3:进行训练
python train.py --data_path=data/family
3.4 实验结果
使用tensorboard可视化结果
4. 关键代码
class GNNLayer(torch.nn.Module):
def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x:x):
super(GNNLayer, self).__init__()
self.n_rel = n_rel
self.in_dim = in_dim
self.out_dim = out_dim
self.attn_dim = attn_dim
self.act = act
self.rela_embed = nn.Embedding(2*n_rel+1, in_dim)
self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
self.Wqr_attn = nn.Linear(in_dim, attn_dim)
self.w_alpha = nn.Linear(attn_dim, 1)
self.W_h = nn.Linear(in_dim, out_dim, bias=False)
def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx):
# edges: [batch_idx, head, rela, tail, old_idx, new_idx]
sub = edges[:,4]
rel = edges[:,2]
obj = edges[:,5]
hs = hidden[sub]
hr = self.rela_embed(rel)
r_idx = edges[:,0]
h_qr = self.rela_embed(q_rel)[r_idx]
message = hs + hr
alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))))
message = alpha * message
message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
hidden_new = self.act(self.W_h(message_agg))
return hidden_new
class RED_GNN_trans(torch.nn.Module):
def __init__(self, params, loader):
super(RED_GNN_trans, self).__init__()
self.n_layer = params.n_layer
self.hidden_dim = params.hidden_dim
self.attn_dim = params.attn_dim
self.n_rel = params.n_rel
self.loader = loader
acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x:x}
act = acts[params.act]
self.gnn_layers = []
for i in range(self.n_layer):
self.gnn_layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
self.gnn_layers = nn.ModuleList(self.gnn_layers)
self.dropout = nn.Dropout(params.dropout)
self.W_final = nn.Linear(self.hidden_dim, 1, bias=False) # get score
self.gate = nn.GRU(self.hidden_dim, self.hidden_dim)
def forward(self, subs, rels, mode='train'):
n = len(subs)
q_sub = torch.LongTensor(subs).cuda()
q_rel = torch.LongTensor(rels).cuda()
h0 = torch.zeros((1, n,self.hidden_dim)).cuda()
nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1)
hidden = torch.zeros(n, self.hidden_dim).cuda()
scores_all = []
for i in range(self.n_layer):
nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode)
hidden = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx)
h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).cuda().index_copy_(1, old_nodes_new_idx, h0)
hidden = self.dropout(hidden)
hidden, h0 = self.gate (hidden.unsqueeze(0), h0)
hidden = hidden.squeeze(0)
scores = self.W_final(hidden).squeeze(-1)
scores_all = torch.zeros((n, self.loader.n_ent)).cuda() # non_visited entities have 0 scores
scores_all[[nodes[:,0], nodes[:,1]]] = scores
return scores_all