PyG搭建GCN实现链接预测

20 篇文章 35 订阅
6 篇文章 7 订阅

前言

关于链接预测的介绍以及链接预测中数据集的划分请参考:链接预测中训练集、验证集以及测试集的划分(以PyG的RandomLinkSplit为例)

1. 数据处理

这里以CiteSeer网络为例:Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。

加载数据:

dataset = Planetoid('data', name='CiteSeer')
print(dataset[0])

输出:

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

x=[3327, 3703]表示一共有3327个节点,然后节点的特征维度为3703,这里实际上是去除停用词和在文档中出现频率小于10次的词,整理得到3703个唯一词。edge_index=[2, 9104],表示一共9104条edge,数据一共两行,每一行都表示节点编号。

利用PyG封装的RandomLinkSplit我们很容易实现数据集的划分:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('data', name='CiteSeer', transform=transform)
train_data, val_data, test_data = dataset[0]

最终我们得到train_data, val_data, test_data

输出一下原始数据集和三个被划分出来的数据集:

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[3642], edge_label_index=[2, 3642])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])
Data(x=[3327, 3703], edge_index=[2, 8194], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])

从上到下依次为原始数据集、训练集、验证集以及测试集。其中,训练集中一共有3642个正样本,验证集和测试集中均为455个正样本+455个负样本。

2. GCN链接预测

本次实验使用GCN来进行链接预测:首先利用GCN对训练集中的节点进行编码,得到节点的向量表示,然后使用这些向量表示对训练集中的正负样本(在每一轮训练时重新采样负样本)进行有监督学习,具体来讲就是利用节点向量求得样本中节点对的内积,然后与标签求损失,最后反向传播更新参数。

2.1 负采样

链接预测训练过程中的每一轮我们都需要对训练集进行采样以得到与正样本数量相同的负样本,验证集和测试集在数据集划分阶段已经进行了负采样,因此不必再进行采样。

负采样函数:

def negative_sample():
    # 从训练集中采样与正边相同数量的负边
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    # print(neg_edge_index.size(1))   # 3642条负边,即每次采样与训练集中正边数量一致的负边
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    return edge_label, edge_label_index

这里用到了negative_sampling方法,其参数有:
在这里插入图片描述
具体来讲,negative_sampling方法利用传入的edge_index参数进行负采样,即采样num_neg_samplesedge_index中不存在的边。num_nodes指定节点个数,method指定采样方法,有sparsedense两种方法。

采样后将neg_edge_index与训练集中原有的正样本train.edge_label_index进行拼接以得到完整的样本集,同时也需要在原本的train_data.edge_label后面添加指定数量的0用于表示负样本。

2.2 模型搭建

GCN链接预测模型搭建如下:

class GCN_LP(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        # z所有节点的表示向量
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        # print(dst.size())   # (7284, 64)
        r = (src * dst).sum(dim=-1)
        # print(r.size())   (7284)
        return r

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)

编码器由一个两层GCN组成,用于得到训练集中节点的向量表示,解码器用于得到训练集中节点对向量间的内积。

由前面可知训练集中的正样本数量为3642,经过负采样函数negative_sample得到3642个负样本,一共7284个样本,最终解码器返回7284个节点对间的内积。

损失函数采用BCEWithLogitsLoss,要想弄懂BCEWithLogitsLoss,就要先了解BCELoss

BCELoss是一种二元交叉熵损失:
在这里插入图片描述

BCEWithLogitsLoss则是在BCELoss的基础上增加了Sigmoid选项,即先把输入经过一个Sigmoid,然后再计算BCELoss

评价指标采用AUC:

roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

2.3 模型训练/测试

代码如下:

def test(model, data):
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.edge_index)
        out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
        model.train()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


def train():
    model = GCN_LP(dataset.num_features, 128, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    min_epochs = 10
    best_model = None
    best_val_auc = 0
    final_test_auc = 0
    model.train()
    for epoch in range(100):
        optimizer.zero_grad()
        edge_label, edge_label_index = negative_sample()
        out = model(train_data.x, train_data.edge_index, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        # validation
        val_auc = test(model, val_data)
        test_auc = test(model, test_data)
        if epoch + 1 > min_epochs and val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc

        print('epoch {:03d} train_loss {:.8f} val_auc {:.4f} test_auc {:.4f}'
              .format(epoch, loss.item(), val_auc, test_auc))

    return final_test_auc

最终测试集上的AUC为:

final best auc: 0.9076681560198044

完整代码

代码地址:GNNs-for-Link-Prediction。原创不易,下载时请给个follow和star!感谢!!

  • 15
    点赞
  • 92
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 12
    评论
GCN (Graph Convolutional Network)是一种用于图数据的机器学习模型,能够利用图结构中的节点和边的信息进行学习和预测任务。PYG (PyTorch Geometric)是基于PyTorch的一个开源库,提供了处理图数据的工具和模型。 GCN链路预测是指利用GCN模型对图数据中不存在的边进行预测,判断这些边在图中是否会存在。这种预测任务在社交网络、生物学、推荐系统等领域具有重要的应用。在PYG中,可以使用其提供的图卷积层和其他模型构建一个GCN链路预测的模型。 在使用PYG进行GCN链路预测时,首先需要构建一个图对象,将节点和边的信息加载到图中。可以使用PYG提供的数据加载器来导入图数据,并将其转换为图对象。然后,需要定义GCN模型的结构,包括图卷积层的设置和激活函数的选择。PYG提供了许多常用的图卷积层和激活函数的实现,可以根据具体任务选择适合的模型结构。 下来,可以使用GCN模型对图数据进行训练和预测。训练阶段,可以使用已知的边来构建训练集,并根据GCN模型的输出与真实标签之间的差异来优化模型参数。预测阶段,可以使用已有的模型对不存在的边进行预测,通常是根据模型输出的概率值或阈值来判断边的存在性。 最后,可以根据预测结果进行评估和分析。常用的评估指标包括准确率、召回率、F1值等,可以通过比较预测结果和真实标签来计算这些指标。此外,还可以通过可视化图数据和GCN模型的注意力机制等来分析模型的学习过程和预测结果。 总之,利用PYG中的GCN模型进行链路预测需要加载图数据、构建模型、训练和预测,并进行评估和分析。通过这一过程,可以预测不存在的边的存在性,为实际应用提供决策和指导。
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值