class RUN_GNN(torch.nn.Module):
######和RED-GNN的class RED_GNN_trans(torch.nn.Module):
def __init__(self, params, loader):
# 初始化方法,接受模型的参数(params)和数据加载器(loader)。
super(RUN_GNN, self).__init__() #调用父类torch.nn.Module的初始化方法。
self.n_layer = params.n_layer #设置模型的图卷积层数。
self.hidden_dim = params.hidden_dim #设置模型中节点表示的隐藏维度。
self.attn_dim = params.attn_dim #设置模型中注意力机制的维度。
self.n_rel = params.n_rel #设置模型中关系的数量。
self.loader = loader #设置数据加载器,用于获取图数据
acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x} #定义激活函数,根据参数中指定的激活函数名称选择对应的激活函数。
act = acts[params.act]
######与RED-GNN相比新加的
# 根据参数中是否存在 uniform_parm 且其值大于0,设置是否使用均匀参数(uniform parameters)。
if "uniform_parm" in params and params.uniform_parm > 0:
self.uniform_parm=True
else:
self.uniform_parm=False
######
self.gnn_layers = [] # 创建G_GAT_Layer层列表
for i in range(self.n_layer):
# 将多个G_GAT_Layer层添加到列表中
self.gnn_layers.append(G_GAT_Layer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
self.gnn_layers = nn.ModuleList(self.gnn_layers)
"""
通过nn.ModuleList将gnn_layers列表转化为PyTorch的ModuleList对象,以便PyTorch能够正确地跟踪和管理模型中的这些层。
ModuleList是PyTorch提供的用于容纳子模块(layers)的容器,使得子模块能够被正确地添加到模型的参数列表中,从而在优化器中进行更新。
"""
######与RED-GNN相比新加的
self.n_extra_layer = params.n_extra_layer #设置额外的图卷积层的数量,从模型参数 (params) 中获取。
self.extra_gnn_layers = [] #创建一个空列表 (extra_gnn_layers) 用于存储额外的图卷积层。
for i in range(self.n_extra_layer):
self.extra_gnn_layers.append(G_GAT_Layer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
self.extra_gnn_layers = nn.ModuleList(self.extra_gnn_layers)
######
#创建丢弃层(dropout)、线性层(W_final)和QRFGU层(gate)。其中,W_final是一个线性层,将隐藏维度映射到大小为1的输出。
self.dropout = nn.Dropout(params.dropout)
self.W_final = nn.Linear(self.hidden_dim, 1, bias=False)
self.gate = QRFGU(self.hidden_dim, self.hidden_dim) # self.gate= nn.GRU(self.hidden_dim, self.hidden_dim) # 定义GRU(门控循环单元)层
def forward(self, subs, rels, mode='transductive'):
# 接受实体(subs)、关系(rels),以及模式(mode)作为输入参数,默认模式为'transductive'。
n = len(subs) # 获取输入实体(subs)的数量。
# 根据模式(mode)选择实体的数量,'transductive'模式下使用全部实体数量,否则使用独立实体的数量。
n_ent = self.loader.n_ent if mode == 'transductive' else self.loader.n_ent_ind
device = next(self.parameters()).device #获取模型的设备信息。
# 将输入的实体和关系转换为PyTorch的张量类型,并将其移动到GPU上。
q_sub = torch.LongTensor(subs).cuda()
q_rel = torch.LongTensor(rels).cuda()
h0 = torch.zeros((1, n, self.hidden_dim)).cuda() # 初始化一个全零张量作为初始隐藏状态(h0)。
nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1) # 创建节点张量,其中包含实体的索引和对应的实体编号。
hidden = torch.zeros(n, self.hidden_dim).cuda() # 初始化一个全零张量用于存储隐藏状态。
scores_all = []
# 循环遍历图神经网络的层数
for i in range(self.n_layer):
# 获取邻居节点、边信息以及节点索引映射
nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode)
if i < (self.n_layer-1):
# 在前 n_layer-1 层使用基本的 GNN 层进行更新
hidden, h_n_qr = self.gnn_layers[0](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx)
else:
# 在最后一层使用特殊的 GNN 层进行更新
hidden, h_n_qr = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx)
if torch.cuda.is_available():
# 如果CUDA可用,创建一个全零张量并将其移动到GPU上,然后使用索引映射从h0复制值到对应的位置
h0 = torch.zeros(nodes.size(0), hidden.size(1)).cuda().index_copy_(0, old_nodes_new_idx, h0)
else:
# 如果CUDA不可用,创建一个全零张量,并使用索引映射从h0复制值到对应的位置
h0 = torch.zeros(nodes.size(0), hidden.size(1)).index_copy_(0, old_nodes_new_idx, h0)
hidden = self.dropout(hidden) # 对隐藏状态进行Dropout操作
hidden = self.gate(hidden, h_n_qr, h0) # 使用门控函数更新隐藏状态,其中h_n_qr是关系嵌入,h0是先前的隐藏状态
h0 = hidden # 更新先前的隐藏状态为当前的隐藏状态
############与RED-GNN相比新加的
for i in range(self.n_extra_layer):
# 对于额外的图神经网络层
hidden = hidden[old_nodes_new_idx] # 使用节点索引映射更新隐藏状态
hidden, h_n_qr = self.extra_gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx) # 在额外的 GNN 层进行更新
hidden = self.dropout(hidden) # 对当前的隐藏状态进行Dropout操作
hidden = self.gate(hidden, h_n_qr, h0) # 使用门控函数更新隐藏状态,其中h_n_qr是关系嵌入,h0是先前的隐藏状态
h0 = hidden # 更新先前的隐藏状态为当前的隐藏状态
######
scores = self.W_final(hidden).squeeze(-1) # 最终的输出分数,通过线性层处理
scores_all = torch.zeros((n, n_ent)).to(device) # 初始化一个全零张量用于存储最终的分数
scores_all[[nodes[:, 0], nodes[:, 1]]] = scores # 将分数按照节点索引映射到相应位置
return scores_all # 返回最终的分数张量
class G_GAT_Layer(torch.nn.Module):
####和RED-GNN的class GNNLayer(torch.nn.Module):
def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x: x):
super(G_GAT_Layer, self).__init__() # 初始化 G_GAT_Layer 类
self.n_rel = n_rel # 图中的关系数量
self.in_dim = in_dim # 输入节点的特征维度
self.out_dim = out_dim # 输出节点的特征维度
self.attn_dim = attn_dim # 注意力维度
self.act = act # 节点更新时使用的激活函数,默认为恒等映射(即不进行激活函数操作)
self.rela_embed = nn.Embedding(2 * n_rel + 1, in_dim) # 定义嵌入层,用于将关系索引映射为关系嵌入向量
################RED-GNN新加的
self.relu = nn.ReLU() # ReLU 激活函数
self.gate = QRFGU(in_dim, in_dim) # 定义 QRFGU 模块,用于计算唯一关系的标识
################
# 定义用于计算注意力的线性变换层
self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
self.Wqr_attn = nn.Linear(in_dim, attn_dim)
self.w_alpha = nn.Linear(attn_dim, 1)
# 定义将消息聚合到节点表示的线性变换层
self.W_h = nn.Linear(in_dim, out_dim, bias=False)
def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx):
################RED-GNN改进的的
unique_edges, inverse_indices = torch.unique(edges[:, [0, 2, 4]],dim=0, sorted=True, return_inverse=True) # 从边中提取唯一的边,以及其在原始边中的逆索引
sub = unique_edges[:, 2] # 获取唯一边的起始节点索引
rel = unique_edges[:, 1] # 获取唯一边的关系索引
# sub = edges[:, 4]
# rel = edges[:, 2]
################
obj = edges[:, 5] # 获取边中的目标节点索引
h_s = hidden[sub] # 从隐藏表示中获取起始节点的表示
h_r = self.rela_embed(rel) # 从关系嵌入矩阵中获取关系嵌入表示
r_idx = unique_edges[:, 0] # 获取唯一边的查询关系索引(在原始边中的索引) r_idx = edges[:, 0]
h_qr = self.rela_embed(q_rel)[r_idx] # 从关系嵌入矩阵中获取查询关系的嵌入表示
# ####################
# message = hs + hr
# alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))))
# message = alpha * message
# message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
#
# hidden_new = self.act(self.W_h(message_agg))
# return hidden_new
# ####################
node_group = torch.zeros(n_node, dtype=torch.int64).to(h_qr.device) #创建一个大小为 n_node 的零张量 node_group,用于存储每个节点对应的分组标识。
node_group[edges[:, 5]] = edges[:, 0] # 根据边的信息更新节点组,将边的第6列作为索引,边的第1列作为值
h_n_qr = self.rela_embed(q_rel)[node_group] # 利用 node_group 从关系嵌入矩阵中获取对应节点的查询关系嵌入向量。
#利用self.gate模块计算唯一关系的标识,这个标识由主体节点、关系和查询关系的嵌入向量计算得到。
unique_relation_identity = self.gate(h_r, h_qr, h_s)
# 将独特关系标识符作为唯一的消息
unique_message = unique_relation_identity
# 计算唯一的关注权重,使用两个线性层的输出和激活函数ReLU
unique_attend_weight = self.w_alpha(self.relu(self.Ws_attn(unique_message) + self.Wqr_attn(h_qr)))
# 计算唯一关注权重的指数,用于后续计算
unique_exp_attend = torch.exp(unique_attend_weight)
# 获取指数化后的关注权重,通过索引变换获取反向的指数
exp_attend = unique_exp_attend[inverse_indices]
#通过将注意力分数乘以唯一关系的标识,得到加权的唯一消息,再根据逆索引将其对应到原始边上,形成加权的消息向量。
unique_message = unique_exp_attend * unique_message
message = unique_message[inverse_indices]
# 利用 scatter 函数对注意力分数进行汇总,得到每个节点的总注意力。
sum_exp_attend = scatter(exp_attend, dim=0, index=obj, dim_size=n_node, reduce="sum")
# 利用 scatter 函数对加权消息进行汇总,得到每个节点的未归一化的消息汇总。
no_attend_message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
# 将未归一化的消息汇总除以总注意力,得到每个节点的归一化消息汇总。
message_agg = no_attend_message_agg / sum_exp_attend
# 对归一化的消息汇总进行线性变换,并应用激活函数 self.act 得到新的节点表示。
hidden_new = self.act(self.W_h(message_agg))
# 返回最终的隐藏状态和查询关系的嵌入向量
return hidden_new, h_n_qr
class QRFGU(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size # 隐藏状态的维度大小
# 门控网络,用于计算更新和重置的值
# 全连接层,输入维度为 hidden_size * 3,输出维度为 hidden_size * 2
# Sigmoid 激活函数,将输出值压缩到 (0, 1) 之间, 用于计算更新(update)和重置(reset)的门控值。
self.gate = nn.Sequential(nn.Linear(self.hidden_size * 3, self.hidden_size * 2), nn.Sigmoid())
# 隐藏状态变换网络,通过调整参数来学习隐藏状态的变换
# 全连接层,输入维度为 hidden_size * 2,输出维度为 hidden_size
# Tanh 激活函数,将输出值压缩到 (-1, 1) 之间, 用于计算更新后的隐藏状态的候选值。
self.hidden_trans = nn.Sequential(nn.Linear(self.hidden_size * 2, self.hidden_size), nn.Tanh())
self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于计算 Sigmoid 的值
self.tanh = nn.Tanh() # Tanh 激活函数,用于计算 Tanh 的值
def forward(self, message: torch.Tensor, query_r: torch.Tensor, hidden_state: torch.Tensor):
"""
:param message: message[batch_size,input_size]
:param query_r: query_r[batch_size,input_size]
:param hidden_state: if it is none,it will be allocated a zero tensor hidden state
:return:
"""
"""
前向传播函数
:param message: 消息张量,维度为 [batch_size, input_size]
:param query_r: 查询张量,维度为 [batch_size, input_size]
:param hidden_state: 隐藏状态张量,如果为 None,将被分配一个全零张量的隐藏状态
:return: 更新后的隐藏状态张量
"""
# 通过门控网络计算更新值和重置值
update_value, reset_value = self.gate(torch.cat([message, query_r, hidden_state], dim=1)).chunk(2, dim=1)
# 通过隐藏状态变换网络计算隐藏状态的候选值
hidden_candidate = self.hidden_trans(torch.cat([message, reset_value * hidden_state], dim=1))
# 更新隐藏状态,按照门控值进行加权求和
hidden_state = (1 - update_value) * hidden_state + update_value * hidden_candidate
# 返回更新后的隐藏状态
return hidden_state