Overlapping Community Detection with Graph Neural Networks: Paper Notes and Code Walkthrough

Paper Notes

Git link

1. Overview

To summarize, our main contributions are:

• Model: We introduce a graph neural network (GNN) based model for overlapping community detection.
• Data: We introduce 4 new datasets for overlapping community detection that can act as a benchmark and stimulate future research in this area.
• Experiments: We perform a thorough evaluation of our model and show its superior performance compared to established methods for overlapping community detection, both in terms of speed and accuracy.

2. Definitions

3. Model (The NOCD Model)

Key idea: combine a GNN with the Bernoulli–Poisson probabilistic model.

The adjacency matrix is generated entry-wise from the community affiliations:

A_uv ~ Bernoulli(1 - exp(-F_u F_v^T))

where F_u is the u-th row of the affiliation matrix F, i.e. node u's (nonnegative) community-affiliation vector.
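To illustrate the generative side of this model, here is a hedged sketch that samples an adjacency matrix from given affiliations (the function name `sample_adjacency` is mine for illustration, not part of the paper's code):

```python
import numpy as np

def sample_adjacency(F, rng=None):
    """Sample a symmetric adjacency matrix from the Bernoulli-Poisson model:
    p(A_uv = 1) = 1 - exp(-F_u . F_v), with nonnegative affiliations F (N x K)."""
    rng = np.random.default_rng(rng)
    probs = 1.0 - np.exp(-F @ F.T)                          # edge probability for every pair
    upper = np.triu(rng.random(probs.shape) < probs, k=1)   # sample the upper triangle only
    A = upper.astype(int)
    return A + A.T                                          # symmetrize; diagonal stays zero
```

Nodes that share a community (large F_u · F_v) are connected with probability close to 1; nodes with orthogonal affiliations are almost never connected.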

Model definition: the affiliations are produced by a GNN,

F := GNN_θ(A, X)

which is instantiated as a two-layer graph convolutional network:

F := ReLU(Â ReLU(Â X W^(1)) W^(2))

where:

Â = D̃^(-1/2) Ã D̃^(-1/2) is the symmetrically normalized adjacency matrix, with Ã = A + I (self-loops added) and D̃ the degree matrix of Ã.

Loss function design

The negative log-likelihood of the Bernoulli–Poisson model is:

-log p(A | F) = -Σ_{(u,v)∈E} log(1 - exp(-F_u F_v^T)) + Σ_{(u,v)∉E} F_u F_v^T

Because A is necessarily sparse, edges are rare and non-edges dominate, so the second term carries almost all of the weight. The paper therefore optimizes a balanced version that averages over edges and non-edges separately:

L(F) = -E_{(u,v)~P_E}[log(1 - exp(-F_u F_v^T))] + E_{(u,v)~P_N}[F_u F_v^T]

where P_E and P_N are uniform distributions over edges and non-edges, respectively.
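The balanced loss on sampled pairs can be sketched as follows. This is a minimal model of what `nocd.nn.BerpoDecoder.loss_batch` computes; the index shapes and numerical details (the `eps` guard, `(B, 2)` index tensors) are my assumptions, not the library's exact implementation:

```python
import torch

def berpo_loss_batch(F, ones_idx, zeros_idx, eps=1e-10):
    """Balanced Bernoulli-Poisson loss on sampled pairs.
    F: nonnegative (N, K) affiliation matrix.
    ones_idx / zeros_idx: (B, 2) index tensors of sampled edges / non-edges."""
    # -E_edges[ log(1 - exp(-F_u . F_v)) ]; -expm1(-x) = 1 - exp(-x), computed stably
    edge_dots = (F[ones_idx[:, 0]] * F[ones_idx[:, 1]]).sum(dim=1)
    loss_edges = -torch.log(-torch.expm1(-(edge_dots + eps))).mean()
    # +E_non-edges[ F_u . F_v ]
    nonedge_dots = (F[zeros_idx[:, 0]] * F[zeros_idx[:, 1]]).sum(dim=1)
    return loss_edges + nonedge_dots.mean()
```

Both terms are now averages, so the loss no longer collapses onto the non-edge term as the graph grows sparser.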

Evaluation

Comparison with other algorithms on different datasets:

A higher NMI value means better performance.

Why use a GNN? (comparison with a traditional MLP)

The MLP structure is as follows:


Code Walkthrough

1. Data Processing

First load the data as a graph, with the following structure:

loader = nocd.data.load_dataset('data/mag_cs.npz')
A, X, Z_gt = loader['A'], loader['X'], loader['Z']  # adjacency, node features, ground-truth affiliations
N, K = Z_gt.shape  # N nodes, K communities

2. Setting the Hyperparameters

hidden_sizes = [128]    # hidden sizes of the GNN
weight_decay = 1e-2     # strength of L2 regularization on GNN weights
dropout = 0.5           # dropout probability
batch_norm = True       # whether to use batch norm
lr = 1e-3               # learning rate
max_epochs = 500        # number of epochs to train
display_step = 25       # how often to compute validation loss
balance_loss = True     # whether to use balanced loss
stochastic_loss = True  # whether to use stochastic or full-batch training
batch_size = 20000      # batch size (only for stochastic training)

3. Choose the input (some network structures depend heavily on node attributes, e.g. citation networks) and normalize the feature matrix (alternatively, feed in A, or A concatenated with X)

x_norm = normalize(X)  # node features
# x_norm = normalize(A)  # adjacency matrix
# x_norm = sp.hstack([normalize(X), normalize(A)])  # concatenate A and X
x_norm = nocd.utils.to_sparse_tensor(x_norm).cuda()
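Here `normalize` is assumed to be `sklearn.preprocessing.normalize`, whose default behavior (row-wise L2 normalization) can be sketched as:

```python
import numpy as np
import scipy.sparse as sp

def l2_row_normalize(X):
    """Row-wise L2 normalization: each nonzero row is scaled to unit Euclidean norm."""
    X = sp.csr_matrix(X, dtype=np.float64)
    norms = np.sqrt(np.asarray(X.multiply(X).sum(axis=1))).ravel()
    norms[norms == 0] = 1.0  # leave all-zero rows unchanged
    return sp.diags(1.0 / norms) @ X
```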

Type conversion of x_norm:

4. Defining the GNN Model

4.1 Setting Up the Edge Sampler

sampler = nocd.sampler.get_edge_sampler(A, batch_size, batch_size, num_workers=2)

The sampler's underlying dataset exposes 2**32 items; each item contains batch_size connected pairs (edges) and batch_size unconnected pairs (non-edges), returned as tensors.
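A simplified stand-in for what one sampler item might look like, using rejection sampling for the non-edges (the real `nocd.sampler` implementation may differ in details such as worker parallelism and duplicate handling):

```python
import numpy as np
import scipy.sparse as sp

def sample_edges_nonedges(A, batch_size, rng=None):
    """Sample batch_size existing edges and batch_size non-edges from a sparse adjacency matrix."""
    rng = np.random.default_rng(rng)
    A = sp.csr_matrix(A)
    N = A.shape[0]
    rows, cols = A.nonzero()
    pick = rng.integers(0, len(rows), size=batch_size)
    ones_idx = np.stack([rows[pick], cols[pick]], axis=1)
    # Rejection sampling: in a sparse graph, a random pair is almost always a non-edge.
    zeros = []
    while len(zeros) < batch_size:
        u, v = rng.integers(0, N, size=2)
        if u != v and A[u, v] == 0:
            zeros.append((u, v))
    return ones_idx, np.array(zeros)
```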

4.2 GNN Model Definition

gnn = nocd.nn.GCN(x_norm.shape[1], hidden_sizes, K, batch_norm=batch_norm, dropout=dropout).cuda()

The GraphConvolution class:

The GNN model:

It consists of two GraphConvolution layers.

The GNN forward pass is illustrated below:
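A minimal sketch of this two-layer architecture, assuming standard GCN layers; the library's `nocd.nn.GCN` additionally supports batch norm and sparse inputs, which are omitted here, and the class names below are mine:

```python
import torch
import torch.nn as nn

class GraphConvolution(nn.Module):
    """One graph convolution: X' = Â (X W), with Â the normalized adjacency."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        return adj @ (x @ self.weight)

class TwoLayerGCN(nn.Module):
    """Two stacked GraphConvolution layers with dropout in between."""
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        self.gc1 = GraphConvolution(input_dim, hidden_dim)
        self.gc2 = GraphConvolution(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        x = torch.relu(self.gc1(self.dropout(x), adj))
        return self.gc2(self.dropout(x), adj)  # final ReLU is applied by the caller
```

Note that, as in the training code below, the final ReLU that makes the affiliations nonnegative is applied outside the model (`F.relu(gnn(x_norm, adj_norm))`).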

4.3 Normalizing the Adjacency Matrix

adj_norm = gnn.normalize_adj(A)
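Assuming `normalize_adj` performs the standard GCN symmetric normalization with self-loops, Â = D̃^(-1/2) (A + I) D̃^(-1/2), it can be sketched as:

```python
import numpy as np
import scipy.sparse as sp

def normalize_adj(A):
    """Symmetric normalization with self-loops: A_hat = D^{-1/2} (A + I) D^{-1/2}."""
    A_tilde = A + sp.eye(A.shape[0])                       # add self-loops
    deg = np.asarray(A_tilde.sum(axis=1)).ravel()          # degrees of A_tilde (always >= 1)
    d_inv_sqrt = sp.diags(deg ** -0.5)
    return (d_inv_sqrt @ A_tilde @ d_inv_sqrt).tocsr()
```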

4.4 Instantiating the Decoder and Optimizer

decoder = nocd.nn.BerpoDecoder(N, A.nnz, balance_loss=balance_loss)
opt = torch.optim.Adam(gnn.parameters(), lr=lr)

5. Defining the Evaluation Function (predictions vs. ground truth)

Principle and details of the NMI_max algorithm

def get_nmi(thresh=0.5):
    gnn.eval()
    Z = F.relu(gnn(x_norm, adj_norm))
    Z_pred = Z.cpu().detach().numpy() > thresh
    nmi = nocd.metrics.overlapping_nmi(Z_pred, Z_gt)
    return nmi
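The same 0.5 threshold can be used to read off the final overlapping communities from the affiliation matrix. A small helper (hypothetical, not part of the library):

```python
import numpy as np

def extract_communities(Z, thresh=0.5):
    """Turn a nonnegative affiliation matrix Z (N x K) into K overlapping node sets
    by thresholding, the same rule get_nmi uses to build Z_pred."""
    Z_pred = Z > thresh
    return [np.flatnonzero(Z_pred[:, k]) for k in range(Z.shape[1])]
```

A node whose affiliation exceeds the threshold in several columns belongs to several communities, which is exactly what makes the detection overlapping.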

6. Training

val_loss = np.inf
validation_fn = lambda: val_loss
early_stopping = nocd.train.NoImprovementStopping(validation_fn, patience=10)
model_saver = nocd.train.ModelSaver(gnn)

for epoch, batch in enumerate(sampler):
    if epoch > max_epochs:
        break
    if epoch % display_step == 0:
        with torch.no_grad():
            gnn.eval()
            # Compute validation loss
            Z = F.relu(gnn(x_norm, adj_norm))
            val_loss = decoder.loss_full(Z, A)
            print(f'Epoch {epoch:4d}, loss.full = {val_loss:.4f}, nmi = {get_nmi():.2f}')
            
            # Check if it's time for early stopping / to save the model
            early_stopping.next_step()
            if early_stopping.should_save():
                model_saver.save()
            if early_stopping.should_stop():
                print(f'Breaking due to early stopping at epoch {epoch}')
                break
            
    # Training step
    gnn.train()
    opt.zero_grad()
    Z = F.relu(gnn(x_norm, adj_norm))
    ones_idx, zeros_idx = batch
    if stochastic_loss:
        loss = decoder.loss_batch(Z, ones_idx, zeros_idx)
    else:
        loss = decoder.loss_full(Z, A)
    loss += nocd.utils.l2_reg_loss(gnn, scale=weight_decay)
    loss.backward()
    opt.step()
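The patience-based logic of `NoImprovementStopping` can be sketched as follows. This is a simplified model of its observable behavior (the class name and internals here are my assumptions, not the library's code):

```python
class SimpleEarlyStopping:
    """Stop when the validation value has not improved for `patience` consecutive checks."""
    def __init__(self, validation_fn, patience=10):
        self.validation_fn = validation_fn
        self.patience = patience
        self.best = float('inf')
        self.since_best = 0

    def next_step(self):
        value = self.validation_fn()
        if value < self.best:       # lower validation loss = improvement
            self.best = value
            self.since_best = 0
        else:
            self.since_best += 1

    def should_save(self):
        return self.since_best == 0  # current model is the best so far

    def should_stop(self):
        return self.since_best >= self.patience
```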

Training output:

Second run:
