AdaProp-inductive关键代码

     def soft_to_hard(self, i, hidden, nodes, n_ent, batch_size, old_nodes_new_idx):
    ##比RED-GNN,AdaProp添加的
        # 获取图中节点的总数
        n_node = len(nodes)
        # 创建一个包含所有节点索引的布尔张量,初始化为True,表示所有节点都不相同
        bool_diff_node_idx = torch.ones(n_node).bool().cuda()
        # 将旧节点映射到新节点的索引设为 False,表示它们是相同的节点,只有old_nodes_new_idx是相同的
        bool_diff_node_idx[old_nodes_new_idx] = False
        # 创建一个包含所有节点索引的布尔张量,表示相同的节点
        bool_same_node_idx = ~bool_diff_node_idx
        # 从所有节点列表中提取不同的节点
        diff_nodes = nodes[bool_diff_node_idx]
        # 计算不同节点的权重
        diff_node_logits = self.Ws_layers[i](hidden[bool_diff_node_idx].detach()).squeeze(-1)

        #软节点和硬节点都是处理不同的节点
        # 创建一个全是负无穷大的张量,用于存储软节点
        soft_all = torch.ones((batch_size, n_ent)) * float('-inf')
        soft_all = soft_all.cuda()
        # 在软节点张量中填充不同节点的权重
        soft_all[diff_nodes[:,0], diff_nodes[:,1]] = diff_node_logits
        # 对软节点张量进行 softmax 操作,得到概率分布
        soft_all = F.softmax(soft_all, dim=-1)

        # 从soft_all中提取top-k的概率值,并将其乘以一个放大因子
        diff_node_logits = self.topk * soft_all[diff_nodes[:,0], diff_nodes[:,1]]
        # 从soft_all中获取每个样本中概率最大的top-k个节点的索引
        _, argtopk = torch.topk(soft_all, k=self.topk, dim=-1)
        # 生成一个形状为(batch_size, self.topk)的张量,其中每行都是0到batch_size-1的整数,重复self.topk次
        #为了在后续的操作中对每个样本进行索引,以便获取对应样本中top - k的节点。
        r_idx = torch.arange(batch_size).unsqueeze(1).repeat(1,self.topk).cuda()
        # 创建一个布尔张量,表示所有节点都是硬节点
        hard_all = torch.zeros((batch_size, n_ent)).bool().cuda()
        # 将每个样本中top-k的节点标记为True
        hard_all[r_idx,argtopk] = True
        # 从采样得到的节点中提取对应的布尔值
        bool_sampled_diff_nodes = hard_all[diff_nodes[:,0], diff_nodes[:,1]]

        # 根据布尔条件和 diff_node_logits 更新 hidden 张量
        hidden[bool_diff_node_idx][bool_sampled_diff_nodes] *= (1 - diff_node_logits[bool_sampled_diff_nodes].detach() + diff_node_logits[bool_sampled_diff_nodes]).unsqueeze(1)
        # 将 bool_sampled_diff_nodes 中为 True (硬节点)更新到 bool_same_node_idx 中
        bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes

        # 返回更新后的隐藏表示和相同节点的索引
        return hidden, bool_same_node_idx

2cab4105f28d47b9a37153ebdc7d1801.png a03026b1b20845cd8f7e018c021733ae.png

1d795a8e5edf4ca6a19dfb40f6a8daa6.png

054f3808a74e4c21ba1db18434c83005.png

cae2311d547c4b1f87db562b536e35f6.png

0c73e310fe554816a46aa23946c18550.png

177956e825d14cca99f9688e2243dca3.png

414126f4912e470ea6389caa44799830.png

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小蜗子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值