def soft_to_hard(self, i, hidden, nodes, n_ent, batch_size, old_nodes_new_idx):
##比RED-GNN,AdaProp添加的
# 获取图中节点的总数
n_node = len(nodes)
# 创建一个包含所有节点索引的布尔张量,初始化为True,表示所有节点都不相同
bool_diff_node_idx = torch.ones(n_node).bool().cuda()
# 将旧节点映射到新节点的索引设为 False,表示它们是相同的节点,只有old_nodes_new_idx是相同的
bool_diff_node_idx[old_nodes_new_idx] = False
# 创建一个包含所有节点索引的布尔张量,表示相同的节点
bool_same_node_idx = ~bool_diff_node_idx
# 从所有节点列表中提取不同的节点
diff_nodes = nodes[bool_diff_node_idx]
# 计算不同节点的权重
diff_node_logits = self.Ws_layers[i](hidden[bool_diff_node_idx].detach()).squeeze(-1)
#软节点和硬节点都是处理不同的节点
# 创建一个全是负无穷大的张量,用于存储软节点
soft_all = torch.ones((batch_size, n_ent)) * float('-inf')
soft_all = soft_all.cuda()
# 在软节点张量中填充不同节点的权重
soft_all[diff_nodes[:,0], diff_nodes[:,1]] = diff_node_logits
# 对软节点张量进行 softmax 操作,得到概率分布
soft_all = F.softmax(soft_all, dim=-1)
# 从soft_all中提取top-k的概率值,并将其乘以一个放大因子
diff_node_logits = self.topk * soft_all[diff_nodes[:,0], diff_nodes[:,1]]
# 从soft_all中获取每个样本中概率最大的top-k个节点的索引
_, argtopk = torch.topk(soft_all, k=self.topk, dim=-1)
# 生成一个形状为(batch_size, self.topk)的张量,其中每行都是0到batch_size-1的整数,重复self.topk次
#为了在后续的操作中对每个样本进行索引,以便获取对应样本中top - k的节点。
r_idx = torch.arange(batch_size).unsqueeze(1).repeat(1,self.topk).cuda()
# 创建一个布尔张量,表示所有节点都是硬节点
hard_all = torch.zeros((batch_size, n_ent)).bool().cuda()
# 将每个样本中top-k的节点标记为True
hard_all[r_idx,argtopk] = True
# 从采样得到的节点中提取对应的布尔值
bool_sampled_diff_nodes = hard_all[diff_nodes[:,0], diff_nodes[:,1]]
# 根据布尔条件和 diff_node_logits 更新 hidden 张量
hidden[bool_diff_node_idx][bool_sampled_diff_nodes] *= (1 - diff_node_logits[bool_sampled_diff_nodes].detach() + diff_node_logits[bool_sampled_diff_nodes]).unsqueeze(1)
# 将 bool_sampled_diff_nodes 中为 True (硬节点)更新到 bool_same_node_idx 中
bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes
# 返回更新后的隐藏表示和相同节点的索引
return hidden, bool_same_node_idx