GraphENS 源码阅读笔记

GraphENS: Neighbor-Aware Ego Network Synthesis for Class-Imbalanced Node Classification
论文源码地址: https://github.com/JoonHyung-Park/GraphENS

论文完整算法
在这里插入图片描述

阅读论文源码时,建议按main函数中运行顺序阅读

只写了gens.py文件中代码的注释

def get_ins_neighbor_dist(num_nodes, edge_index, train_mask, device):
    """
    Compute adjacent node distribution.
    节点数量较多时,建议使用稀疏tensor
    """
    ## Utilize GPU ##
    train_mask = train_mask.clone().to(device)
    edge_index = edge_index.clone().to(device)
    row, col = edge_index[0], edge_index[1]

    # Compute neighbor distribution
    neighbor_dist_list = []
    # 计算每个节点的邻居分布
    for j in range(num_nodes):
        neighbor_dist = torch.zeros(num_nodes, dtype=torch.float32).to(device)

        # 统计入边
        idx = row[(col==j)]
        # 相连加1
        neighbor_dist[idx] = neighbor_dist[idx] + 1
        neighbor_dist_list.append(neighbor_dist)

    # 按行拼接
    neighbor_dist_list = torch.stack(neighbor_dist_list,dim=0)  # [num_nodes, num_nodes]
    # 每行归一化
    neighbor_dist_list = F.normalize(neighbor_dist_list,dim=1,p=1)

    return neighbor_dist_list

def sampling_idx_individual_dst(class_num_list, idx_info, device):
    """
    Samples source and target nodes
    过采样
    class_num_list 每类节点数量
    idx_info 每类节点     按类别排序
    """
    # Selecting src & dst nodes
    # 具有最多样本数的类
    max_num, n_cls = max(class_num_list), len(class_num_list)
    # 最大数量 - 每类的数量      sampling_list 每类采样数量
    # 源节点只包含 minor 节点,不包括最多类
    # 采样至比例为 1:1
    sampling_list = max_num * torch.ones(n_cls) - torch.tensor(class_num_list)
    new_class_num_list = torch.Tensor(class_num_list).to(device)

    # Compute # of source nodes
    # torch.randint(len(cls_idx), (int(samp_num.item()),)) 每类产生 samp_num 个节点
    sampling_src_idx =[cls_idx[torch.randint(len(cls_idx),(int(samp_num.item()),))]
                        for cls_idx, samp_num in zip(idx_info, sampling_list)]

    sampling_src_idx = torch.cat(sampling_src_idx)     # [# of augmented nodes,]

    # Generate corresponding destination nodes
    class_dst_idx= []
    # 每类采样概率,这里概率和不为1
    prob = torch.log(new_class_num_list.float())/ new_class_num_list.float()
    # 每个节点的采样概率    [N_nodes,]
    prob = prob.repeat_interleave(new_class_num_list.long())
    # 节点 这里是按类别排序
    temp_idx_info = torch.cat(idx_info)     # [# of nodes,]

    # 采样目标节点,返回对应索引
    dst_idx = torch.multinomial(prob, sampling_src_idx.shape[0], True)
    # 获取目标节点
    sampling_dst_idx = temp_idx_info[dst_idx]

    # Sorting src idx with corresponding dst idx
    # 升序排列节点
    sampling_src_idx, sorted_idx = torch.sort(sampling_src_idx)
    sampling_dst_idx = sampling_dst_idx[sorted_idx]

    return sampling_src_idx, sampling_dst_idx

def saliency_mixup(x, sampling_src_idx, sampling_dst_idx, lam, saliency=None,
                   dist_kl = None, keep_prob = 0.3):
    """
    Saliency-based node mixing - Mix node features
    Input:
        x:                  Node features; [# of nodes, input feature dimension]
        sampling_src_idx:   Source node index for augmented nodes; [# of augmented nodes]
        sampling_dst_idx:   Target node index for augmented nodes; [# of augmented nodes]
        lam:                Sampled mixing ratio; [# of augmented nodes, 1]
        saliency:           Saliency map of input feature; [# of nodes, input feature dimension]
        dist_kl:             KLD between source node and target node predictions; [# of augmented nodes, 1]
        keep_prob:          Ratio of keeping source node feature; scalar
    Output:
        new_x:              [# of original nodes + # of augmented nodes, feature dimension]
    """
    total_node = x.shape[0]
    ## Mixup ##
    new_src = x[sampling_src_idx.to(x.device), :].clone()
    new_dst = x[sampling_dst_idx.to(x.device), :].clone()
    lam = lam.to(x.device)

    # Saliency Mixup
    if saliency != None:
        node_dim = saliency.shape[1]
        saliency_dst = saliency[sampling_dst_idx].abs()
        saliency_dst += 1e-10
        # 构造一个分布
        # 原论文写的softmax?
        saliency_dst /= torch.sum(saliency_dst, dim=1).unsqueeze(1)    # [# of augmented nodes, node_dim]

        # 保留的src节点特征数量
        K = int(node_dim * keep_prob)
        # mask为True 表示保留src节点特征
        mask_idx = torch.multinomial(saliency_dst, K)   # [# of augmented nodes, K]
        lam = lam.expand(-1,node_dim).clone()   # [# of augmented nodes, node_dim]
        if dist_kl != None: # Adaptive
            # kl 散度越大, kl_mask值越大,保留更多src节点特征
            kl_mask = (torch.sigmoid(dist_kl/3.) * K).squeeze().long()  # [# of augmented nodes,]
            # [1, K]  >= [# of augmented nodes, 1]  --> [# of augmented nodes, K]
            # 0,1矩阵 每行前kl_mask个元素为0
            idx_matrix = (torch.arange(K).unsqueeze(dim=0).to(kl_mask.device) >= kl_mask.unsqueeze(dim=1))
            # 取mask_idx第一列,重复 K
            zero_repeat_idx = mask_idx[:,0:1].repeat(1,mask_idx.size(1))    # [# of augmented nodes, K]
            mask_idx[idx_matrix] = zero_repeat_idx[idx_matrix]  # [# of augmented nodes, K]

        # 每行对应mask_idx 置1
        lam[torch.arange(lam.shape[0]).unsqueeze(1), mask_idx] = 1.
    mixed_node = lam * new_src + (1-lam) * new_dst
    new_x = torch.cat([x, mixed_node], dim =0)
    return new_x

每行对应mask_idx 置1
lam[torch.arange(lam.shape[0]).unsqueeze(1), mask_idx] = 1.
在这里插入图片描述


def duplicate_neighbor(total_node, edge_index, sampling_src_idx):
    """
    Duplicate edges of source nodes for sampled nodes.
    Input:
        total_node:         # of nodes; scalar
        edge_index:         Edge index; [2, # of edges]
        sampling_src_idx:   Source node index for augmented nodes; [# of augmented nodes]
    Output:
        new_edge_index:     original_edge_index + duplicated_edge_index
    """
    device = edge_index.device

    # Assign node index for augmented nodes
    row, col = edge_index[0], edge_index[1]
    row, sort_idx = torch.sort(row)
    col = col[sort_idx]
    # 每个节点的入度
    # 有BUG? 若末尾节点入度为0,degree不会加入此节点,导致后续索引越界
    degree = scatter_add(torch.ones_like(col), col)
    # repeat前为 src 节点加入图后的编号
    # repeat 根据度数重复 这里没有0度节点
    new_row =(torch.arange(len(sampling_src_idx)).to(device)+ total_node).repeat_interleave(degree[sampling_src_idx])
    # 相同 src 节点计数
    temp = scatter_add(torch.ones_like(sampling_src_idx), sampling_src_idx).to(device)

    # Duplicate the edges of source nodes
    node_mask = torch.zeros(total_node, dtype=torch.bool)
    unique_src = torch.unique(sampling_src_idx)
    node_mask[unique_src] = True
    row_mask = node_mask[row]
    # row中从src节点出发所指向的节点
    edge_mask = col[row_mask]

    b_idx = torch.arange(len(unique_src)).to(device).repeat_interleave(degree[unique_src])
    # 一个src节点为一行,每列为src所指节点
    # 列数不足补 -1
    # to_dense_batch 具体看官网文档
    # 这里edge_dense的行数为存在两种情况
    # 列数为unique_src中最大入度
    edge_dense, _ = to_dense_batch(edge_mask, b_idx, fill_value=-1)
    if len(temp[temp!=0]) != edge_dense.shape[0]:   # 末尾存在度为0的节点
        cut_num =len(temp[temp!=0]) - edge_dense.shape[0]
        cut_temp = temp[temp!=0][:-cut_num]
    else:
        cut_temp = temp[temp!=0]    # [# of unique_src,]
    edge_dense = edge_dense.repeat_interleave(cut_temp, dim=0)
    new_col = edge_dense[edge_dense!= -1]
    inv_edge_index = torch.stack([new_col, new_row], dim=0)
    new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1)

    return new_edge_index

结合to_dense_batch函数理解

官方文档地址:to_dense_batch

  • b_idx中只存在unique_src中入度不为0的节点,且每个节点重复其入度
  • to_dense_batch相当于获取每个src节点所连的节点
  • 根据节点入度,有如下情况
    • 情况1:b_idx不连续,即unique_src中存在入度为0的节点,且该节点之后有入度不为0的节点,此时to_dense_batch的输出会存在占位情况
    • 情况2:b_idx连续,即unique_src中入度为0的节点,只出现在入度不为0节点之后

情况1:(注意 batch的内容
在这里插入图片描述
情况2:
在这里插入图片描述


def get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx):
    """
    Compute KL divergence
    """
    device = prev_out.device
    # kl 散度设置reduction='none'的计算结果为[# of sampling nodes, # of sampling nodes]
    dist_kl = F.kl_div(torch.log(prev_out[sampling_dst_idx.to(device)]), prev_out[sampling_src_idx.to(device)], \
                    reduction='none').sum(dim=1,keepdim=True)
    dist_kl[dist_kl<0] = 0
    return dist_kl  # [# of sampling nodes, 1]

def neighbor_sampling(total_node, edge_index, sampling_src_idx, sampling_dst_idx,
        neighbor_dist_list, prev_out, train_node_mask=None):
    """
    Neighbor Sampling - Mix adjacent node distribution and samples neighbors from it
    Input:
        total_node:         # of nodes; scalar
        edge_index:         Edge index; [2, # of edges]
        sampling_src_idx:   Source node index for augmented nodes; [# of augmented nodes]
        sampling_dst_idx:   Target node index for augmented nodes; [# of augmented nodes]
        neighbor_dist_list: Adjacent node distribution of whole nodes; [# of nodes, # of nodes]
        prev_out:           Model prediction of the previous step; [# of nodes, n_cls]
        train_node_mask:    Mask for not removed nodes; [# of nodes]
    Output:
        new_edge_index:     original edge index + sampled edge index
        dist_kl:            kl divergence of target nodes from source nodes; [# of sampling nodes, 1]
    """
    ## Exception Handling ##
    device = edge_index.device
    n_candidate = 1
    sampling_src_idx = sampling_src_idx.clone().to(device)

    # Find the nearest nodes and mix target pool
    # 论文公式(1) 概率混合
    if prev_out is not None:
        sampling_dst_idx = sampling_dst_idx.clone().to(device)
        dist_kl = get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx)     # [# of augmented nodes, 1]
        # kl散度生成一列0 kl取反 拼接 softmax
        # 这里同时计算了论文公式(1)中的两个混合系数
        ratio = F.softmax(torch.cat([dist_kl.new_zeros(dist_kl.size(0),1), -dist_kl], dim=1), dim=1)    # [# of augmented nodes, 2]
        # src节点混合
        # 这里ratio[:,:1]保持了  [# of sampling nodes,  1] 最后一维
        mixed_neighbor_dist = ratio[:,:1] * neighbor_dist_list[sampling_src_idx]    # [# of augmented nodes,  # of nodes]
        for i in range(n_candidate):
            # target 节点概率混合
            mixed_neighbor_dist += ratio[:,i+1:i+2] * neighbor_dist_list[sampling_dst_idx.unsqueeze(dim=1)[:,i]]
    else:
        mixed_neighbor_dist = neighbor_dist_list[sampling_src_idx]  # [# of augmented nodes, # of nodes]

    # Compute degree
    col = edge_index[1]
    # 每个节点入度
    degree = scatter_add(torch.ones_like(col), col)
    # 补充末尾入度为0的节点
    if len(degree) < total_node:
        degree = torch.cat([degree, degree.new_zeros(total_node-len(degree))],dim=0)
    if train_node_mask is None:
        train_node_mask = torch.ones_like(degree,dtype=torch.bool)
    # 相同入度有几个节点
    # 索引为对应度数,内容为个数
    degree_dist = scatter_add(torch.ones_like(degree[train_node_mask]), degree[train_node_mask]).to(device).type(torch.float32)

    # Sample degree for augmented nodes
    # 复制 degree_dist
    prob = degree_dist.unsqueeze(dim=0).repeat(len(sampling_src_idx),1)     # [# of augmented nodes, len(degree_dist)]
    # 每个src节点取一个度
    # torch.multinomial 返回的是prob每行的索引
    # aug_degree为采样到的度数
    aug_degree = torch.multinomial(prob, 1).to(device).squeeze(dim=1)   # (m)   [# of augmented nodes,]
    max_degree = degree.max().item() + 1
    # src节点度数和采样度数取最小的
    aug_degree = torch.min(aug_degree, degree[sampling_src_idx])    # [# of augmented nodes,]

    # Sample neighbors
    # 每个src节点采样最大度数个节点,new_tgt为节点索引
    new_tgt = torch.multinomial(mixed_neighbor_dist + 1e-12, max_degree)    # [# of augmented nodes, max_degree]
    # 从0到最大度数
    tgt_index = torch.arange(max_degree).unsqueeze(dim=0).to(device)    # [1, max_degree]
    #  [1, max_degree] - [# of sampling nodes, 1] --> [# of augmented nodes, max_degree]
    #  每个src节点只取aug_degree个节点
    new_col = new_tgt[(tgt_index - aug_degree.unsqueeze(dim=1) < 0)]
    new_row = (torch.arange(len(sampling_src_idx)).to(device)+ total_node)
    new_row = new_row.repeat_interleave(aug_degree)
    inv_edge_index = torch.stack([new_col, new_row], dim=0)
    new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1)

    return new_edge_index, dist_kl

class MeanAggregation(MessagePassing):
    def __init__(self):
        super(MeanAggregation, self).__init__(aggr='mean')

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值