Paper Notes
1. Topic
To summarize, our main contributions are:
• Model: We introduce a graph neural network (GNN) based model for overlapping community detection.
• Data: We introduce 4 new datasets for overlapping community detection that can act as a benchmark and stimulate future research in this area.
• Experiments: We perform a thorough evaluation of our model and show its superior performance compared to established methods for overlapping community detection, both in terms of speed and accuracy.
2. Definitions
3. Model (THE NOCD MODEL)
Key idea: combine a GNN with the Bernoulli–Poisson probabilistic model.
The adjacency matrix is modeled as
A_uv ~ Bernoulli(1 - exp(-F_u F_v^T))
where F_u, the u-th row of the affiliation matrix F, is the community-affiliation vector of node u.
Model definition:
F := GNN_θ(A, X)
Concretely, with a 2-layer GCN:
F := ReLU(Â ReLU(Â X W^(1)) W^(2))
where Â = D̃^(-1/2) Ã D̃^(-1/2) is the symmetrically normalized adjacency matrix, with Ã = A + I_N (self-loops added) and D̃ its degree matrix.
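To make the link model concrete, here is a tiny NumPy sketch (purely illustrative; `edge_prob` and the toy matrix `F` are made up for this note) of how affiliation vectors translate into edge probabilities under the Bernoulli–Poisson model:

```python
import numpy as np

# Bernoulli–Poisson link probability: p_uv = 1 - exp(-F_u . F_v)
def edge_prob(F_u, F_v):
    return 1.0 - np.exp(-F_u @ F_v)

# Toy affiliation matrix: rows are nodes, columns are communities.
F = np.array([
    [2.0, 0.0],   # node 0: strongly in community 0
    [1.5, 0.0],   # node 1: strongly in community 0
    [0.0, 2.0],   # node 2: strongly in community 1
])

p_same = edge_prob(F[0], F[1])   # shared community -> high probability
p_diff = edge_prob(F[0], F[2])   # no shared community -> probability 0
```

Nodes that share a strong community get an edge probability close to 1, while nodes with disjoint affiliations get probability exactly 0 under this model.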
Loss design
The negative log-likelihood of the Bernoulli–Poisson model is
-log p(A | F) = -Σ_{(u,v)∈E} log(1 - exp(-F_u F_v^T)) + Σ_{(u,v)∉E} F_u F_v^T
Because real graphs are sparse, the vast majority of node pairs (u, v) are non-edges, so the second sum dominates the loss. To counteract this imbalance, the authors instead optimize a balanced version that averages each term over uniformly sampled edges and non-edges:
L(F) = -E_{(u,v)~P_E}[log(1 - exp(-F_u F_v^T))] + E_{(u,v)~P_N}[F_u F_v^T]
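This balanced objective can be sketched in a few lines of NumPy (an illustrative stand-in, not the library's `BerpoDecoder`; the function name and the `eps` guard are my own):

```python
import numpy as np

def balanced_bp_loss(F, edges, non_edges, eps=1e-8):
    # Edge term, averaged over observed edges (u, v) in E.
    u, v = edges.T
    edge_scores = np.einsum('ij,ij->i', F[u], F[v])          # F_u . F_v per pair
    edge_term = -np.mean(np.log(1.0 - np.exp(-edge_scores) + eps))
    # Non-edge term, averaged over sampled non-edges.
    u, v = non_edges.T
    non_edge_term = np.mean(np.einsum('ij,ij->i', F[u], F[v]))
    return edge_term + non_edge_term
```

Averaging each term over its own pair distribution means the enormous number of non-edges in a sparse graph no longer drowns out the edge term.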
Evaluation
Comparison with other algorithms on different datasets.
Higher NMI indicates better recovery of the ground-truth communities.
Why use a GNN? (comparison with a plain MLP)
The MLP architecture is as follows:
Code Walkthrough
1. Data loading
First load the dataset as a graph:
loader = nocd.data.load_dataset('data/mag_cs.npz')
A, X, Z_gt = loader['A'], loader['X'], loader['Z']  # adjacency, node features, ground-truth communities
N, K = Z_gt.shape  # number of nodes, number of communities
2. Set the hyperparameters
hidden_sizes = [128] # hidden sizes of the GNN
weight_decay = 1e-2 # strength of L2 regularization on GNN weights
dropout = 0.5 # dropout probability
batch_norm = True # whether to use batch norm
lr = 1e-3 # learning rate
max_epochs = 500 # number of epochs to train
display_step = 25 # how often to compute validation loss
balance_loss = True # whether to use balanced loss
stochastic_loss = True # whether to use stochastic or full-batch training
batch_size = 20000 # batch size (only for stochastic training)
3. Choose the input representation (some networks, e.g. paper-citation networks, depend heavily on node attributes) and normalize it; alternatively, feed in A alone, or A concatenated with X:
x_norm = normalize(X) # node features
# x_norm = normalize(A) # adjacency matrix
# x_norm = sp.hstack([normalize(X), normalize(A)]) # concatenate A and X
x_norm = nocd.utils.to_sparse_tensor(x_norm).cuda()
Type conversion of x_norm:
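The conversion most likely turns the scipy sparse matrix into a torch sparse COO tensor; a plausible sketch (the real `nocd.utils.to_sparse_tensor` may differ in details):

```python
import numpy as np
import scipy.sparse as sp
import torch

def to_sparse_tensor(matrix):
    coo = matrix.tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, coo.shape)

X = sp.random(5, 3, density=0.4, format='csr', random_state=0)
x_t = to_sparse_tensor(X)   # torch sparse tensor of shape (5, 3)
```

The `.cuda()` call in the notebook then just moves this sparse tensor to the GPU.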
4. Define the GNN model
4.1 Create the edge sampler
sampler = nocd.sampler.get_edge_sampler(A, batch_size, batch_size, num_workers=2)
The sampler behaves like a dataset with 2**32 items; each item contains batch_size connected edges and batch_size non-edges, returned as Tensors.
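Conceptually the sampler does something like the following sketch (`sample_batch` and its rejection loop are illustrative, not the library's implementation):

```python
import numpy as np
import scipy.sparse as sp

def sample_batch(A, num_pos, num_neg, rng):
    # Positive pairs: uniformly sampled existing edges.
    rows, cols = A.nonzero()
    idx = rng.integers(0, len(rows), size=num_pos)
    ones_idx = np.stack([rows[idx], cols[idx]], axis=1)
    # Negative pairs: random node pairs, rejected until they are non-edges.
    n = A.shape[0]
    zeros = []
    while len(zeros) < num_neg:
        u, v = rng.integers(0, n, size=2)
        if u != v and A[u, v] == 0:
            zeros.append((u, v))
    return ones_idx, np.array(zeros)
```

These two index sets are exactly what the stochastic loss consumes later as `ones_idx` and `zeros_idx`.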
4.2 GNN model definition
gnn = nocd.nn.GCN(x_norm.shape[1], hidden_sizes, K, batch_norm=batch_norm, dropout=dropout).cuda()
Definition of the GraphConvolution class:
GNN model:
It consists of two GraphConvolution layers.
The GNN forward pass is shown in the flowchart below:
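A minimal self-contained version of this structure (a sketch in the spirit of the note, not the library's classes):

```python
import torch
import torch.nn as nn

class GraphConvolution(nn.Module):
    """One graph convolution: propagate features over the normalized adjacency, then apply weights."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, adj):
        return adj @ self.linear(x)

class TinyGCN(nn.Module):
    """Two stacked GraphConvolution layers with a ReLU in between."""
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.gc1 = GraphConvolution(in_dim, hidden_dim)
        self.gc2 = GraphConvolution(hidden_dim, out_dim)

    def forward(self, x, adj):
        return self.gc2(torch.relu(self.gc1(x, adj)), adj)
```

The final ReLU that produces the non-negative affiliations F is applied outside the model, as in the training loop below.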
4.3 Normalize the adjacency matrix
adj_norm = gnn.normalize_adj(A)
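A GCN-style symmetric normalization, sketched here as a guess at what `normalize_adj` computes (the library may differ in details, e.g. whether self-loops are added):

```python
import numpy as np
import scipy.sparse as sp

def normalize_adj(A):
    A_tilde = A + sp.eye(A.shape[0])               # add self-loops
    deg = np.asarray(A_tilde.sum(axis=1)).ravel()  # degrees incl. self-loops
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    return d_inv_sqrt @ A_tilde @ d_inv_sqrt       # D^-1/2 (A + I) D^-1/2
```

This is the standard preprocessing step that keeps feature propagation numerically stable across layers.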
4.4 Instantiate the decoder and the optimizer
decoder = nocd.nn.BerpoDecoder(N, A.nnz, balance_loss=balance_loss)
opt = torch.optim.Adam(gnn.parameters(), lr=lr)
5. Define the evaluation function (predicted vs. ground-truth communities)
def get_nmi(thresh=0.5):
    gnn.eval()
    Z = F.relu(gnn(x_norm, adj_norm))
    Z_pred = Z.cpu().detach().numpy() > thresh  # binarize affiliations
    nmi = nocd.metrics.overlapping_nmi(Z_pred, Z_gt)
    return nmi
6. Training
val_loss = np.inf
validation_fn = lambda: val_loss
early_stopping = nocd.train.NoImprovementStopping(validation_fn, patience=10)
model_saver = nocd.train.ModelSaver(gnn)
for epoch, batch in enumerate(sampler):
    if epoch > max_epochs:
        break
    if epoch % 25 == 0:
        with torch.no_grad():
            gnn.eval()
            # Compute validation loss
            Z = F.relu(gnn(x_norm, adj_norm))
            val_loss = decoder.loss_full(Z, A)
            print(f'Epoch {epoch:4d}, loss.full = {val_loss:.4f}, nmi = {get_nmi():.2f}')

            # Check if it's time for early stopping / to save the model
            early_stopping.next_step()
            if early_stopping.should_save():
                model_saver.save()
            if early_stopping.should_stop():
                print(f'Breaking due to early stopping at epoch {epoch}')
                break

    # Training step
    gnn.train()
    opt.zero_grad()
    Z = F.relu(gnn(x_norm, adj_norm))
    ones_idx, zeros_idx = batch
    if stochastic_loss:
        loss = decoder.loss_batch(Z, ones_idx, zeros_idx)
    else:
        loss = decoder.loss_full(Z, A)
    loss += nocd.utils.l2_reg_loss(gnn, scale=weight_decay)
    loss.backward()
    opt.step()
Training output:
Second run: