AdaProp-transductive关键代码

    def forward(self, q_sub, q_rel, hidden, edges, nodes, old_nodes_new_idx, batchsize):
    # def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx):   ###RED-GNN
        # edges: [N_edge_of_all_batch, 6] 
        # with (batch_idx, head, rela, tail, head_idx, tail_idx)
        # note that head_idx and tail_idx are relative index
        sub = edges[:,4]
        rel = edges[:,2]
        obj = edges[:,5]

        hs = hidden[sub]
        hr = self.rela_embed(rel)

        r_idx = edges[:,0]
        h_qr = self.rela_embed(q_rel)[r_idx]

        n_node = nodes.shape[0]   ##########相比RED-GNN,AdaProp添加的

        message = hs + hr
        #######################################################################################3相比RED-GNN,AdaProp添加的
        # sample edges w.r.t. alpha
        if self.n_edge_topk > 0:
            # 计算边的权重 alpha
            alpha = self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))).squeeze(-1)
            # 使用 Gumbel Softmax 采样边的概率分布
            edge_prob = F.gumbel_softmax(alpha, tau=1, hard=False)
            # 选择概率最高的前 n_edge_topk 个边
            topk_index = torch.argsort(edge_prob, descending=True)[:self.n_edge_topk]
            # 创建一个全是零的张量,将被选择的边的位置设为1
            edge_prob_hard = torch.zeros((alpha.shape[0])).cuda()
            edge_prob_hard[topk_index] = 1
            # 根据采样结果更新边的权重 alpha
            alpha *= (edge_prob_hard - edge_prob.detach() + edge_prob)
            # 对 alpha 进行 sigmoid 操作,并在最后一维添加一个维度
            alpha = torch.sigmoid(alpha).unsqueeze(-1)
        else:
            # 如果没有进行 top-k 采样,则直接计算边的权重 alpha
            alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) # [N_edge_of_all_batch, 1]

        # aggregate message and then propagate
        message = alpha * message
        message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')
        hidden_new = self.act(self.W_h(message_agg)) # [n_node, dim]
        hidden_new = hidden_new.clone()
        
        # forward without node sampling
        if self.n_node_topk <= 0:
            return hidden_new

        # forward with node sampling
        # indexing sampling operation

        # 创建一个包含所有节点索引的布尔张量,初始化为True,表示所有节点都不相同  bool_diff_node_idx
        tmp_diff_node_idx = torch.ones(n_node)
        # 将旧节点映射到新节点的索引设为0,表示它们是相同的节点,只有old_nodes_new_idx是相同的
        tmp_diff_node_idx[old_nodes_new_idx] = 0
        # 将张量转换为布尔类型张量,表示不同的节点为True,相同的节点为False
        bool_diff_node_idx = tmp_diff_node_idx.bool()
        # 从所有节点列表中提取不同的节点
        diff_node = nodes[bool_diff_node_idx]
        # 计算不同节点的权重
        diff_node_logit = self.W_samp(hidden_new[bool_diff_node_idx]).squeeze(-1) # [all_batch_new_nodes]

        # 创建一个全是负无穷大的张量,用于存储软节点 soft_all
        node_scores = torch.ones((batchsize, self.n_ent)).cuda() * float('-inf')
        # 在软节点张量中填充不同节点的权重
        node_scores[diff_node[:,0], diff_node[:,1]] = diff_node_logit
        # select top-k nodes   选择 top-k 节点
        # (train mode) self.softmax == F.gumbel_softmax  (训练模式)self.softmax == F.gumbel_softmax
        # (eval mode)  self.softmax == F.softmax   (评估模式)self.softmax == F.softmax
        # 对软节点张量进行 softmax 操作,得到概率分布
        node_scores = self.softmax(node_scores) # [batchsize, n_ent]

        # 从soft_all中提取top-k的概率值,并将其乘以一个放大因子,获取每个样本中概率最大的top-k个节点的索引
        topk_index = torch.topk(node_scores, self.n_node_topk, dim=1).indices.reshape(-1)
        # 为前k个节点创建批次索引 # 创建一个重复的索引张量
        topk_batchidx = torch.arange(batchsize).repeat(self.n_node_topk,1).T.reshape(-1)
        # 创建一个布尔张量,表示所有节点都是硬节点  创建一个形状为 (batchsize, n_ent) 的零张量,用于表示批次中前 k 个节点 hard_all
        batch_topk_nodes = torch.zeros((batchsize, self.n_ent)).cuda()
        # 将每个样本中top-k的节点标记为1
        batch_topk_nodes[topk_batchidx, topk_index] = 1
        # 通过索引 diff_node 张量,获取与不同节点对应的布尔值
        bool_sampled_diff_nodes_idx = batch_topk_nodes[diff_node[:,0], diff_node[:,1]].bool()

        # 得到相同节点的索引
        bool_same_node_idx = ~bool_diff_node_idx.cuda()
        # 更新相同节点索引,将 bool_sampled_diff_nodes_idx 中选中的节点位置标记为 True
        bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes_idx

        # update node embeddings
        # 根据索引 diff_node 提取与不同节点对应的概率值  bool_sampled_diff_nodes = hard_all[diff_nodes[:,0], diff_nodes[:,1]]
        diff_node_prob_hard = batch_topk_nodes[diff_node[:,0], diff_node[:,1]]
        # 根据索引 diff_node 提取与不同节点对应的原始概率值
        diff_node_prob = node_scores[diff_node[:,0], diff_node[:,1]]
        #根据不同节点的采样概率和原始概率进行一些计算,并在最后一个维度上添加一个维度,然后将结果乘以与不同节点对应的部分
        hidden_new[bool_diff_node_idx] *= (diff_node_prob_hard - diff_node_prob.detach() + diff_node_prob).unsqueeze(-1)
        
        # extract sampled nodes an their embeddings
        # 根据相同节点的索引提取相同节点
        new_nodes = nodes[bool_same_node_idx]
        # 根据相同节点的索引更新 hidden_new 张量,即保留与相同节点对应的部分
        hidden_new = hidden_new[bool_same_node_idx]

        return hidden_new, new_nodes, bool_same_node_idx

a427f1e3be4f4b649c59f3c7ad9dba89.png

alpha *= (edge_prob_hard - edge_prob.detach() + edge_prob)

d444d29887e74f9eb66c0ebccff0b9c6.png

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小蜗子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值