def forward(self, q_sub, q_rel, hidden, edges, nodes, old_nodes_new_idx, batchsize):
# def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx): ###RED-GNN
# edges: [N_edge_of_all_batch, 6]
# with (batch_idx, head, rela, tail, head_idx, tail_idx)
# note that head_idx and tail_idx are relative index
sub = edges[:,4]
rel = edges[:,2]
obj = edges[:,5]
hs = hidden[sub]
hr = self.rela_embed(rel)
r_idx = edges[:,0]
h_qr = self.rela_embed(q_rel)[r_idx]
n_node = nodes.shape[0] ##########相比RED-GNN,AdaProp添加的
message = hs + hr
#######################################################################################3相比RED-GNN,AdaProp添加的
# sample edges w.r.t. alpha
if self.n_edge_topk > 0:
# 计算边的权重 alpha
alpha = self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))).squeeze(-1)
# 使用 Gumbel Softmax 采样边的概率分布
edge_prob = F.gumbel_softmax(alpha, tau=1, hard=False)
# 选择概率最高的前 n_edge_topk 个边
topk_index = torch.argsort(edge_prob, descending=True)[:self.n_edge_topk]
# 创建一个全是零的张量,将被选择的边的位置设为1
edge_prob_hard = torch.zeros((alpha.shape[0])).cuda()
edge_prob_hard[topk_index] = 1
# 根据采样结果更新边的权重 alpha
alpha *= (edge_prob_hard - edge_prob.detach() + edge_prob)
# 对 alpha 进行 sigmoid 操作,并在最后一维添加一个维度
alpha = torch.sigmoid(alpha).unsqueeze(-1)
else:
# 如果没有进行 top-k 采样,则直接计算边的权重 alpha
alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) # [N_edge_of_all_batch, 1]
# aggregate message and then propagate
message = alpha * message
message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
hidden_new = self.act(self.W_h(message_agg)) # [n_node, dim]
hidden_new = hidden_new.clone()
# forward without node sampling
if self.n_node_topk <= 0:
return hidden_new
# forward with node sampling
# indexing sampling operation
# 创建一个包含所有节点索引的布尔张量,初始化为True,表示所有节点都不相同 bool_diff_node_idx
tmp_diff_node_idx = torch.ones(n_node)
# 将旧节点映射到新节点的索引设为0,表示它们是相同的节点,只有old_nodes_new_idx是相同的
tmp_diff_node_idx[old_nodes_new_idx] = 0
# 将张量转换为布尔类型张量,表示不同的节点为True,相同的节点为False
bool_diff_node_idx = tmp_diff_node_idx.bool()
# 从所有节点列表中提取不同的节点
diff_node = nodes[bool_diff_node_idx]
# 计算不同节点的权重
diff_node_logit = self.W_samp(hidden_new[bool_diff_node_idx]).squeeze(-1) # [all_batch_new_nodes]
# 创建一个全是负无穷大的张量,用于存储软节点 soft_all
node_scores = torch.ones((batchsize, self.n_ent)).cuda() * float('-inf')
# 在软节点张量中填充不同节点的权重
node_scores[diff_node[:,0], diff_node[:,1]] = diff_node_logit
# select top-k nodes 选择 top-k 节点
# (train mode) self.softmax == F.gumbel_softmax (训练模式)self.softmax == F.gumbel_softmax
# (eval mode) self.softmax == F.softmax (评估模式)self.softmax == F.softmax
# 对软节点张量进行 softmax 操作,得到概率分布
node_scores = self.softmax(node_scores) # [batchsize, n_ent]
# 从soft_all中提取top-k的概率值,并将其乘以一个放大因子,获取每个样本中概率最大的top-k个节点的索引
topk_index = torch.topk(node_scores, self.n_node_topk, dim=1).indices.reshape(-1)
# 为前k个节点创建批次索引 # 创建一个重复的索引张量
topk_batchidx = torch.arange(batchsize).repeat(self.n_node_topk,1).T.reshape(-1)
# 创建一个布尔张量,表示所有节点都是硬节点 创建一个形状为 (batchsize, n_ent) 的零张量,用于表示批次中前 k 个节点 hard_all
batch_topk_nodes = torch.zeros((batchsize, self.n_ent)).cuda()
# 将每个样本中top-k的节点标记为1
batch_topk_nodes[topk_batchidx, topk_index] = 1
# 通过索引 diff_node 张量,获取与不同节点对应的布尔值
bool_sampled_diff_nodes_idx = batch_topk_nodes[diff_node[:,0], diff_node[:,1]].bool()
# 得到相同节点的索引
bool_same_node_idx = ~bool_diff_node_idx.cuda()
# 更新相同节点索引,将 bool_sampled_diff_nodes_idx 中选中的节点位置标记为 True
bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes_idx
# update node embeddings
# 根据索引 diff_node 提取与不同节点对应的概率值 bool_sampled_diff_nodes = hard_all[diff_nodes[:,0], diff_nodes[:,1]]
diff_node_prob_hard = batch_topk_nodes[diff_node[:,0], diff_node[:,1]]
# 根据索引 diff_node 提取与不同节点对应的原始概率值
diff_node_prob = node_scores[diff_node[:,0], diff_node[:,1]]
#根据不同节点的采样概率和原始概率进行一些计算,并在最后一个维度上添加一个维度,然后将结果乘以与不同节点对应的部分
hidden_new[bool_diff_node_idx] *= (diff_node_prob_hard - diff_node_prob.detach() + diff_node_prob).unsqueeze(-1)
# extract sampled nodes an their embeddings
# 根据相同节点的索引提取相同节点
new_nodes = nodes[bool_same_node_idx]
# 根据相同节点的索引更新 hidden_new 张量,即保留与相同节点对应的部分
hidden_new = hidden_new[bool_same_node_idx]
return hidden_new, new_nodes, bool_same_node_idx
alpha *= (edge_prob_hard - edge_prob.detach() + edge_prob)