PyG搭建R-GCN实现链接预测

20 篇文章 35 订阅
6 篇文章 7 订阅

前言

关于链接预测的介绍以及链接预测中数据集的划分请参考:链接预测中训练集、验证集以及测试集的划分(以PyG的RandomLinkSplit为例)

1. 数据处理

导入数据:

path = os.path.abspath(os.path.dirname(os.getcwd())) + '\data\DBLP'
dataset = DBLP(path)
graph = dataset[0]
print(graph)

输出如下:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP数据集中有作者(author)、论文(paper)、术语(term)以及会议(conference)四种类型的节点。DBLP中包含14328篇论文(paper), 4057位作者(author), 20个会议(conference), 7723个术语(term)。作者分为四个领域:数据库、数据挖掘、机器学习、信息检索。

由于conference节点没有特征,因此需要预先设置特征:

graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))

所有conference节点的特征都初始化为[1]

利用PyG封装的RandomLinkSplit我们很容易实现数据集的划分:

train_data, val_data, test_data = T.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        disjoint_train_ratio=0,
        edge_types=[('author', 'to', 'paper'), ('paper', 'to', 'term'),
                    ('paper', 'to', 'conference')],
        rev_edge_types=[('paper', 'to', 'author'), ('term', 'to', 'paper'),
                        ('conference', 'to', 'paper')]
    )(graph.to_homogeneous())

最终我们得到train_data, val_data, test_data

输出一下原始数据集和三个被划分出来的数据集:

Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])
Data(node_type=[26128], edge_index=[2, 191654], edge_type=[191654], edge_label=[95827], edge_label_index=[2, 95827])
Data(node_type=[26128], edge_index=[2, 191654], edge_type=[191654], edge_label=[23956], edge_label_index=[2, 23956])
Data(node_type=[26128], edge_index=[2, 215610], edge_type=[215610], edge_label=[23956], edge_label_index=[2, 23956])

从上到下依次为原始数据集、训练集、验证集以及测试集。其中,训练集中一共有95827个正样本,验证集和测试集中均为11978个正样本+11978个负样本。

2. R-GCN链接预测

本次实验使用R-GCN来进行链接预测:首先利用R-GCN对训练集中的节点进行编码,得到节点的向量表示,然后使用这些向量表示对训练集中的正负样本(在每一轮训练时重新采样负样本)进行有监督学习。具体来讲就是将一条边上的两个特征向量进行拼接然后送入一个全连接层进行二分类

2.1 负采样

链接预测训练过程中的每一轮我们都需要对训练集进行采样以得到与正样本数量相同的负样本,验证集和测试集在数据集划分阶段已经进行了负采样,因此不必再进行采样。

负采样函数:

def negative_sample(data):
    # 从训练集中采样与正边相同数量的负边
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.edge_label_index.size(1), method='sparse')
    # print(neg_edge_index.size(1))   # 4488条负边,即每次采样与训练集中正边数量一致的负边
    edge_label_index = torch.cat(
        [data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        data.edge_label,
        data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    return edge_label, edge_label_index

这里用到了negative_sampling方法,其参数有:
在这里插入图片描述
具体来讲,negative_sampling方法利用传入的edge_index参数进行负采样,即采样num_neg_samplesedge_index中不存在的边。num_nodes指定节点个数,method指定采样方法,有sparsedense两种方法。

采样后将neg_edge_index与训练集中原有的正样本train.edge_label_index进行拼接以得到完整的样本集,同时也需要在原本的train_data.edge_label后面添加指定数量的0用于表示负样本。

2.2 模型搭建

R-GCN链接预测模型搭建如下:

class RGCN_LP(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(RGCN_LP, self).__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        self.lins = torch.nn.ModuleList()
        for i in range(len(node_types)):
            lin = nn.Linear(init_sizes[i], in_channels)
            self.lins.append(lin)

        self.fc = nn.Sequential(
            nn.Linear(2 * out_channels, 1),
            nn.Sigmoid()
        )

    def trans_dimensions(self, xs):
        res = []
        for x, lin in zip(xs, self.lins):
            res.append(lin(x))
        return torch.cat(res, dim=0)

    def encode(self, data):
        x = self.trans_dimensions(init_x)
        edge_index, edge_type = data.edge_index, data.edge_type
        x = F.relu(self.conv1(x, edge_index, edge_type))
        # x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_type)

        return x

    def decode(self, z, edge_label_index):
        # print(z.shape)
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        x = torch.cat([src, dst], dim=-1)
        x = self.fc(x)

        return x

    def forward(self, data, edge_label_index):
        z = self.encode(data)
        return self.decode(z, edge_label_index)

由于DBLP中不同类型的节点具有不同的特征空间,因此我们首先需要将所有节点的特征转换到同一维度:

def trans_dimensions(self, xs):
    res = []
    for x, lin in zip(xs, self.lins):
        res.append(lin(x))
    return torch.cat(res, dim=0)

其中xs为所有节点的特征集合:

init_x = [graph[node_type].x for node_type in node_types]

这里之所以可以将所有类型的x按照顺序进行拼接,是因为train_data等三个数据集中的edge_index是按照节点顺序进行编号的,即author, paper, paper, conference的顺序进行编码。

2.3 模型训练/测试

参考前面:PyG搭建GCN实现链接预测

训练:

def train():
    model = RGCN_LP(in_feats, hidden_feats, 128).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCELoss().to(device)
    min_epochs = 10
    min_val_loss = np.Inf
    final_test_auc = 0
    final_test_ap = 0
    model.train()
    for epoch in tqdm(range(100)):
        optimizer.zero_grad()
        edge_label, edge_label_index = negative_sample(train_data)
        out = model(train_data, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        # validation
        val_loss, test_auc, test_ap = test(model, val_data, test_data)
        if epoch + 1 > min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            final_test_auc = test_auc
            final_test_ap = test_ap

        print('epoch {:03d} train_loss {:.8f} val_loss {:.4f} test_auc {:.4f} test_ap {:.4f}'
              .format(epoch, loss.item(), val_loss, test_auc, test_ap))

    return final_test_auc, final_test_ap

测试:

@torch.no_grad()
def test(model, val_data, test_data):
    model.eval()
    # cal val loss
    criterion = torch.nn.BCELoss().to(device)
    out = model(val_data, val_data.edge_label_index).view(-1)
    val_loss = criterion(out, val_data.edge_label)
    # cal metrics
    out = model(test_data, test_data.edge_label_index).view(-1)
    model.train()

    auc, ap = get_metrics(out, test_data.edge_label)

    return val_loss, auc, ap

训练100轮:

final best auc: 0.875985331793656
final best ap: 0.78863051487488

完整代码

代码地址:GNNs-for-Link-Prediction。原创不易,下载时请给个follow和star!感谢!!

  • 12
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值