GraphENS: Neighbor-Aware Ego Network Synthesis for Class-Imbalanced Node Classification
Paper source code: https://github.com/JoonHyung-Park/GraphENS
Full algorithm from the paper
When reading the source code, it is best to follow the order in which the functions are called from main.
Only the code in the gens.py file is annotated below.
# Imports used by the snippets below
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, to_dense_batch

def get_ins_neighbor_dist(num_nodes, edge_index, train_mask, device):
    """
    Compute adjacent node distribution.
    For graphs with many nodes, a sparse tensor is preferable
    (see the sketch after this function).
    """
    ## Utilize GPU ##
    train_mask = train_mask.clone().to(device)
    edge_index = edge_index.clone().to(device)
    row, col = edge_index[0], edge_index[1]

    # Compute neighbor distribution
    neighbor_dist_list = []
    # Build the neighbor distribution of every node
    for j in range(num_nodes):
        neighbor_dist = torch.zeros(num_nodes, dtype=torch.float32).to(device)
        # Incoming edges of node j
        idx = row[(col == j)]
        # Add 1 for each connected source node
        neighbor_dist[idx] = neighbor_dist[idx] + 1
        neighbor_dist_list.append(neighbor_dist)

    # Stack row-wise
    neighbor_dist_list = torch.stack(neighbor_dist_list, dim=0)  # [num_nodes, num_nodes]
    # L1-normalize each row
    neighbor_dist_list = F.normalize(neighbor_dist_list, dim=1, p=1)
    return neighbor_dist_list
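The sparse-tensor hint in the docstring can be made concrete. Below is a minimal sketch of my own (get_ins_neighbor_dist_sparse is not part of the repository), assuming an unweighted edge_index; coalesce() sums duplicate entries, which mirrors how the dense loop counts parallel edges.

import torch
from torch_scatter import scatter_add

def get_ins_neighbor_dist_sparse(num_nodes, edge_index):
    # Hypothetical sparse variant; avoids the dense [num_nodes, num_nodes] tensor.
    row, col = edge_index[0], edge_index[1]
    # In-degree of each node, used as the row normalizer
    deg = scatter_add(torch.ones_like(col, dtype=torch.float32), col,
                      dim_size=num_nodes).clamp(min=1)
    # Entry (j, i) = 1/deg(j) for every edge i -> j, matching the dense version
    values = 1.0 / deg[col]
    return torch.sparse_coo_tensor(torch.stack([col, row]), values,
                                   (num_nodes, num_nodes)).coalesce()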
def sampling_idx_individual_dst(class_num_list, idx_info, device):
    """
    Samples source and target nodes (oversampling).
    class_num_list: number of training nodes per class
    idx_info: node indices of each class, grouped by class
    """
    # Selecting src & dst nodes
    # max_num: size of the largest class
    max_num, n_cls = max(class_num_list), len(class_num_list)
    # sampling_list[c] = max_num - |class c|: how many nodes to synthesize per class.
    # Source nodes come only from minor classes (the largest class gets 0),
    # so oversampling brings every class to a 1:1 ratio.
    sampling_list = max_num * torch.ones(n_cls) - torch.tensor(class_num_list)
    new_class_num_list = torch.Tensor(class_num_list).to(device)

    # Compute # of source nodes
    # torch.randint(len(cls_idx), (int(samp_num.item()),)) draws samp_num nodes per class
    sampling_src_idx = [cls_idx[torch.randint(len(cls_idx), (int(samp_num.item()),))]
                        for cls_idx, samp_num in zip(idx_info, sampling_list)]
    sampling_src_idx = torch.cat(sampling_src_idx)  # [# of augmented nodes,]

    # Generate corresponding destination nodes
    # Per-class sampling weight log(n)/n; the weights need not sum to 1,
    # torch.multinomial normalizes them internally
    prob = torch.log(new_class_num_list.float()) / new_class_num_list.float()
    # Per-node sampling weight, [# of nodes,]
    prob = prob.repeat_interleave(new_class_num_list.long())
    # Node indices, grouped by class to match prob
    temp_idx_info = torch.cat(idx_info)  # [# of nodes,]
    # Sample destination nodes (with replacement); returns positions into temp_idx_info
    dst_idx = torch.multinomial(prob, sampling_src_idx.shape[0], True)
    # Look up the destination node indices
    sampling_dst_idx = temp_idx_info[dst_idx]

    # Sort src idx (ascending) together with the corresponding dst idx
    sampling_src_idx, sorted_idx = torch.sort(sampling_src_idx)
    sampling_dst_idx = sampling_dst_idx[sorted_idx]
    return sampling_src_idx, sampling_dst_idx
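A toy run (values assumed for illustration) makes the oversampling visible: every class is topped up to the size of the largest one. Note that the per-node destination weight log(n)/n is 0 for a class of size 1, so such nodes are never drawn as destinations.

import torch

# 3 classes with 5/3/1 training nodes
class_num_list = [5, 3, 1]
idx_info = [torch.arange(0, 5), torch.arange(5, 8), torch.tensor([8])]
src, dst = sampling_idx_individual_dst(class_num_list, idx_info, 'cpu')
# sampling_list = [0, 2, 4]: class 0 gets no new nodes, class 1 gets 2,
# class 2 gets 4, so every class ends up with max_num = 5 samples
print(src.shape)  # torch.Size([6])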
def saliency_mixup(x, sampling_src_idx, sampling_dst_idx, lam, saliency=None,
                   dist_kl=None, keep_prob=0.3):
    """
    Saliency-based node mixing - Mix node features
    Input:
        x: Node features; [# of nodes, input feature dimension]
        sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes]
        sampling_dst_idx: Target node index for augmented nodes; [# of augmented nodes]
        lam: Sampled mixing ratio; [# of augmented nodes, 1]
        saliency: Saliency map of input feature; [# of nodes, input feature dimension]
        dist_kl: KLD between source node and target node predictions; [# of augmented nodes, 1]
        keep_prob: Ratio of keeping source node feature; scalar
    Output:
        new_x: [# of original nodes + # of augmented nodes, feature dimension]
    """
    total_node = x.shape[0]

    ## Mixup ##
    new_src = x[sampling_src_idx.to(x.device), :].clone()
    new_dst = x[sampling_dst_idx.to(x.device), :].clone()
    lam = lam.to(x.device)

    # Saliency Mixup
    if saliency is not None:
        node_dim = saliency.shape[1]
        saliency_dst = saliency[sampling_dst_idx].abs()
        saliency_dst += 1e-10
        # Turn each row into a distribution
        # (the paper describes this as a softmax; the code uses plain L1 normalization)
        saliency_dst /= torch.sum(saliency_dst, dim=1).unsqueeze(1)  # [# of augmented nodes, node_dim]

        # Number of src feature dimensions to keep
        K = int(node_dim * keep_prob)
        # Dimensions listed in mask_idx keep the src node feature
        mask_idx = torch.multinomial(saliency_dst, K)  # [# of augmented nodes, K]
        lam = lam.expand(-1, node_dim).clone()  # [# of augmented nodes, node_dim]

        if dist_kl is not None:  # Adaptive
            # The larger the KL divergence, the larger kl_mask,
            # i.e. the more src feature dimensions are kept
            kl_mask = (torch.sigmoid(dist_kl / 3.) * K).squeeze().long()  # [# of augmented nodes,]
            # [1, K] >= [# of augmented nodes, 1] --> [# of augmented nodes, K]
            # Boolean matrix; in each row the first kl_mask entries are False
            idx_matrix = (torch.arange(K).unsqueeze(dim=0).to(kl_mask.device) >= kl_mask.unsqueeze(dim=1))
            # First column of mask_idx, repeated K times
            zero_repeat_idx = mask_idx[:, 0:1].repeat(1, mask_idx.size(1))  # [# of augmented nodes, K]
            # Beyond the first kl_mask positions, collapse to the first index
            mask_idx[idx_matrix] = zero_repeat_idx[idx_matrix]  # [# of augmented nodes, K]

        # Set lam to 1 at the positions listed in mask_idx (keep src there)
        lam[torch.arange(lam.shape[0]).unsqueeze(1), mask_idx] = 1.

    mixed_node = lam * new_src + (1 - lam) * new_dst
    new_x = torch.cat([x, mixed_node], dim=0)
    return new_x
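The final write into lam is the core trick: an advanced-indexing assignment that pins lam to 1 on the kept dimensions, so the mix keeps the source feature there. A minimal demo with toy values:

import torch

lam = torch.full((2, 4), 0.3)
mask_idx = torch.tensor([[0, 2], [1, 3]])  # K = 2 kept dims per augmented node
lam[torch.arange(2).unsqueeze(1), mask_idx] = 1.
print(lam)
# tensor([[1.0000, 0.3000, 1.0000, 0.3000],
#         [0.3000, 1.0000, 0.3000, 1.0000]])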
def duplicate_neighbor(total_node, edge_index, sampling_src_idx):
    """
    Duplicate edges of source nodes for sampled nodes.
    Input:
        total_node: # of nodes; scalar
        edge_index: Edge index; [2, # of edges]
        sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes]
    Output:
        new_edge_index: original_edge_index + duplicated_edge_index
    """
    device = edge_index.device

    # Assign node index for augmented nodes
    row, col = edge_index[0], edge_index[1]
    row, sort_idx = torch.sort(row)
    col = col[sort_idx]
    # In-degree of every node.
    # Bug in the original code: without dim_size, trailing nodes with in-degree 0
    # are missing from degree, and indexing it below goes out of bounds;
    # dim_size=total_node fixes that.
    degree = scatter_add(torch.ones_like(col), col, dim_size=total_node)
    # Before repeat: the indices the augmented nodes get once appended to the graph;
    # repeat_interleave duplicates each by its src node's degree (no zero-degree nodes here)
    new_row = (torch.arange(len(sampling_src_idx)).to(device) + total_node).repeat_interleave(degree[sampling_src_idx])
    # How many times each src node was sampled
    temp = scatter_add(torch.ones_like(sampling_src_idx), sampling_src_idx).to(device)

    # Duplicate the edges of source nodes
    node_mask = torch.zeros(total_node, dtype=torch.bool)
    unique_src = torch.unique(sampling_src_idx)
    node_mask[unique_src] = True
    row_mask = node_mask[row]
    # Neighbors reached from src nodes: the col endpoints of edges whose row endpoint is a src
    edge_mask = col[row_mask]
    b_idx = torch.arange(len(unique_src)).to(device).repeat_interleave(degree[unique_src])
    # One row per src node; the columns are the nodes that src points to,
    # padded with -1 up to the largest degree within unique_src.
    # See the official docs for to_dense_batch; the number of rows of
    # edge_dense falls into two cases, discussed after this function.
    edge_dense, _ = to_dense_batch(edge_mask, b_idx, fill_value=-1)

    if len(temp[temp != 0]) != edge_dense.shape[0]:  # trailing nodes with degree 0
        cut_num = len(temp[temp != 0]) - edge_dense.shape[0]
        cut_temp = temp[temp != 0][:-cut_num]
    else:
        cut_temp = temp[temp != 0]  # [# of unique_src,]
    edge_dense = edge_dense.repeat_interleave(cut_temp, dim=0)
    new_col = edge_dense[edge_dense != -1]
    inv_edge_index = torch.stack([new_col, new_row], dim=0)
    new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1)
    return new_edge_index
Understanding this with to_dense_batch (official docs: to_dense_batch):

b_idx contains only those nodes of unique_src whose in-degree is non-zero, each repeated as many times as its in-degree, so to_dense_batch effectively collects the neighbors of every src node. Depending on the in-degrees, two cases arise (see the demo below):

- Case 1: b_idx is non-contiguous, i.e. unique_src contains a zero-in-degree node that is followed by nodes with non-zero in-degree; the output of to_dense_batch then contains placeholder rows for the empty batches (note the contents of batch).
- Case 2: b_idx is contiguous, i.e. zero-in-degree nodes of unique_src appear only after all nodes with non-zero in-degree.
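A small demo with toy tensors (not from the repository) reproduces both cases:

import torch
from torch_geometric.utils import to_dense_batch

# Case 1: b_idx skips batch 1, so to_dense_batch emits a placeholder row
edge_mask = torch.tensor([10, 11, 12])
b_idx = torch.tensor([0, 0, 2])          # batch 1 is missing
dense, _ = to_dense_batch(edge_mask, b_idx, fill_value=-1)
print(dense)
# tensor([[10, 11],
#         [-1, -1],   <- placeholder row for the empty batch
#         [12, -1]])

# Case 2: b_idx is contiguous; trailing empty batches simply do not appear,
# which is why edge_dense can have fewer rows than len(temp[temp != 0])
b_idx = torch.tensor([0, 0, 1])
dense, _ = to_dense_batch(edge_mask, b_idx, fill_value=-1)
print(dense)
# tensor([[10, 11],
#         [12, -1]])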
def get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx):
    """
    Compute KL divergence
    """
    device = prev_out.device
    # With reduction='none', F.kl_div returns an elementwise result of shape
    # [# of sampling nodes, n_cls]; summing over dim=1 gives one KL value per pair
    dist_kl = F.kl_div(torch.log(prev_out[sampling_dst_idx.to(device)]),
                       prev_out[sampling_src_idx.to(device)],
                       reduction='none').sum(dim=1, keepdim=True)
    dist_kl[dist_kl < 0] = 0
    return dist_kl  # [# of sampling nodes, 1]
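A quick numeric check (toy probabilities, assumed values) of what this call computes: with the first argument in log-space and the second as probabilities, F.kl_div(log q, p, reduction='none').sum(1) is exactly KL(p ‖ q) per row, here the divergence of the dst prediction from the src prediction.

import torch
import torch.nn.functional as F

prev_out = torch.tensor([[0.9, 0.1],
                         [0.5, 0.5]])
src = torch.tensor([0]); dst = torch.tensor([1])
kl = F.kl_div(torch.log(prev_out[dst]), prev_out[src],
              reduction='none').sum(dim=1, keepdim=True)
# Matches sum_c p_src(c) * (log p_src(c) - log p_dst(c)) ~= 0.368
print(kl)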
def neighbor_sampling(total_node, edge_index, sampling_src_idx, sampling_dst_idx,
                      neighbor_dist_list, prev_out, train_node_mask=None):
    """
    Neighbor Sampling - Mix adjacent node distribution and samples neighbors from it
    Input:
        total_node: # of nodes; scalar
        edge_index: Edge index; [2, # of edges]
        sampling_src_idx: Source node index for augmented nodes; [# of augmented nodes]
        sampling_dst_idx: Target node index for augmented nodes; [# of augmented nodes]
        neighbor_dist_list: Adjacent node distribution of whole nodes; [# of nodes, # of nodes]
        prev_out: Model prediction of the previous step; [# of nodes, n_cls]
        train_node_mask: Mask for not removed nodes; [# of nodes]
    Output:
        new_edge_index: original edge index + sampled edge index
        dist_kl: kl divergence of target nodes from source nodes; [# of sampling nodes, 1]
    """
    ## Exception Handling ##
    device = edge_index.device
    n_candidate = 1
    sampling_src_idx = sampling_src_idx.clone().to(device)

    # Find the nearest nodes and mix target pool
    # Eq. (1) of the paper: mixing the neighbor distributions
    if prev_out is not None:
        sampling_dst_idx = sampling_dst_idx.clone().to(device)
        dist_kl = get_dist_kl(prev_out, sampling_src_idx, sampling_dst_idx)  # [# of augmented nodes, 1]
        # Concatenate a zero column with -dist_kl and take a softmax;
        # this yields both mixing coefficients of Eq. (1) at once
        ratio = F.softmax(torch.cat([dist_kl.new_zeros(dist_kl.size(0), 1), -dist_kl], dim=1), dim=1)  # [# of augmented nodes, 2]
        # src side of the mixture;
        # ratio[:, :1] keeps the trailing dim, so the shape stays [# of augmented nodes, 1]
        mixed_neighbor_dist = ratio[:, :1] * neighbor_dist_list[sampling_src_idx]  # [# of augmented nodes, # of nodes]
        for i in range(n_candidate):
            # dst side of the mixture
            mixed_neighbor_dist += ratio[:, i + 1:i + 2] * neighbor_dist_list[sampling_dst_idx.unsqueeze(dim=1)[:, i]]
    else:
        mixed_neighbor_dist = neighbor_dist_list[sampling_src_idx]  # [# of augmented nodes, # of nodes]
        dist_kl = None  # otherwise the return below would raise a NameError

    # Compute degree
    col = edge_index[1]
    # In-degree of every node
    degree = scatter_add(torch.ones_like(col), col)
    # Pad for trailing nodes with in-degree 0
    if len(degree) < total_node:
        degree = torch.cat([degree, degree.new_zeros(total_node - len(degree))], dim=0)
    if train_node_mask is None:
        train_node_mask = torch.ones_like(degree, dtype=torch.bool)
    # Histogram of degrees: index = degree value, entry = # of nodes with that degree
    degree_dist = scatter_add(torch.ones_like(degree[train_node_mask]),
                              degree[train_node_mask]).to(device).type(torch.float32)

    # Sample degree for augmented nodes
    # Replicate degree_dist once per augmented node
    prob = degree_dist.unsqueeze(dim=0).repeat(len(sampling_src_idx), 1)  # [# of augmented nodes, len(degree_dist)]
    # Draw one degree per src node; torch.multinomial returns column indices
    # within each row of prob, which here equal the sampled degree values
    aug_degree = torch.multinomial(prob, 1).to(device).squeeze(dim=1)  # [# of augmented nodes,]
    max_degree = degree.max().item() + 1
    # Cap the sampled degree at the src node's own degree
    aug_degree = torch.min(aug_degree, degree[sampling_src_idx])  # [# of augmented nodes,]

    # Sample neighbors
    # Draw max_degree candidate neighbors per src node; new_tgt holds node indices
    new_tgt = torch.multinomial(mixed_neighbor_dist + 1e-12, max_degree)  # [# of augmented nodes, max_degree]
    # 0 .. max_degree-1
    tgt_index = torch.arange(max_degree).unsqueeze(dim=0).to(device)  # [1, max_degree]
    # [1, max_degree] - [# of augmented nodes, 1] --> [# of augmented nodes, max_degree];
    # keep only the first aug_degree candidates of each src node
    new_col = new_tgt[(tgt_index - aug_degree.unsqueeze(dim=1) < 0)]
    new_row = (torch.arange(len(sampling_src_idx)).to(device) + total_node)
    new_row = new_row.repeat_interleave(aug_degree)
    inv_edge_index = torch.stack([new_col, new_row], dim=0)
    new_edge_index = torch.cat([edge_index, inv_edge_index], dim=1)
    return new_edge_index, dist_kl
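The degree-truncation mask near the end is easy to misread; a toy example (assumed values) shows that each augmented node keeps only its first aug_degree candidate neighbors:

import torch

new_tgt = torch.tensor([[7, 3, 9],
                        [2, 8, 4]])        # max_degree = 3 candidates each
aug_degree = torch.tensor([2, 1])
tgt_index = torch.arange(3).unsqueeze(0)   # [1, max_degree]
new_col = new_tgt[(tgt_index - aug_degree.unsqueeze(1)) < 0]
print(new_col)  # tensor([7, 3, 2]) -- 2 neighbors for node 0, 1 for node 1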
class MeanAggregation(MessagePassing):
    def __init__(self):
        super(MeanAggregation, self).__init__(aggr='mean')

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Add self-loops to the adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Propagate messages (mean aggregation over neighbors)
        return self.propagate(edge_index, x=x)
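A usage sketch on a toy graph (assumed data): with self-loops added, every node averages its own feature with those of its in-neighbors.

import torch

x = torch.tensor([[1.0], [3.0], [5.0]])
edge_index = torch.tensor([[0, 1],
                           [2, 2]])        # edges 0->2 and 1->2
agg = MeanAggregation()
out = agg(x, edge_index)
print(out)  # node 2 averages x[0], x[1] and itself: (1 + 3 + 5) / 3 = 3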