Graph Neural Networks, Part 4: The InMemoryDataset Base Class and Node/Edge Prediction Tasks in Practice


Introduction

This is the fourth post of the Datawhale study group series on graph neural networks. These notes summarize the key points of the lesson, along with simple code implementations.

The InMemoryDataset Base Class

In PyG, we define a custom dataset class whose data can be held entirely in memory by inheriting from InMemoryDataset.

First, the constructor signature of the InMemoryDataset class:

class InMemoryDataset(
    root: Optional[str] = None,
    transform: Optional[Callable] = None,
    pre_transform: Optional[Callable] = None,
    pre_filter: Optional[Callable] = None)

  • root: string, the path of the folder that stores the dataset. It contains two subfolders, raw_dir and processed_dir (both are property methods, so the folder names can be customized):

    raw_dir: stores the unprocessed files; raw dataset files downloaded from the network are placed here.

    processed_dir: the processed data is saved here; afterwards, loading files from this folder is enough to obtain Data objects.

  • transform: callable, a data-transformation function that takes a Data object and returns a transformed Data object. It is executed on every data access: the data-access function first applies it to the Data object and only then returns the data, so it should be used for data augmentation. Defaults to None, i.e. no transformation.

  • pre_transform: callable, a data-transformation function that takes a Data object and returns a transformed Data object. It is called before the Data object is saved to disk, so it should be used for preprocessing that only needs to run once. Defaults to None, i.e. no preprocessing.

  • pre_filter: callable, a function that decides whether a sample is kept. It takes a Data object and returns whether that Data object should be included in the final dataset. It is also called before the Data object is saved to disk. Defaults to None, i.e. no filtering, keeping all samples. A usage sketch of these three hooks follows this list.
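To make the timing of these hooks concrete, here is a minimal sketch; the root path and the two transforms are illustrative assumptions, and it uses the PlanetoidPubMed class defined later in this post (which forwards transform and pre_transform to InMemoryDataset):

import torch_geometric.transforms as T

dataset = PlanetoidPubMed(
    root='node_classify/pubmed',          # raw/ and processed/ live under this folder
    transform=T.NormalizeFeatures(),      # runs on every dataset[i] access (augmentation)
    pre_transform=T.AddSelfLoops(),       # runs once, before saving to processed_dir
)
# A pre_filter would be passed the same way by dataset classes that forward
# it, e.g. pre_filter=lambda data: data.num_nodes > 0.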

To build our own dataset class by inheriting from InMemoryDataset, we need to implement the following four members:

  • raw_file_names(): a property method that returns the list of names of the dataset's raw files. The raw files should be found in the raw_dir folder; otherwise download() is called to fetch them into raw_dir.
  • processed_file_names(): a property method that returns the list of names of the files that store the processed data. These files should be found in the processed_dir folder; otherwise process() is called to process the samples and save the result to files under processed_dir.
  • download(): downloads the raw dataset files into the raw_dir folder.
  • process(): processes the data and saves the processed result to files under processed_dir.

The PlanetoidPubMed class below implements all four:
import os.path as osp

import torch
from torch_geometric.data import (InMemoryDataset, download_url)
from torch_geometric.io import read_planetoid_data

class PlanetoidPubMed(InMemoryDataset):
    r"""Nodes represent papers and edges represent citation links.
    The training, validation and test splits are given as binary masks.

    Args:
        root (string): path of the folder that stores the dataset.
        transform (callable, optional): data-transformation function,
            called on every data access.
        pre_transform (callable, optional): data-transformation function,
            called before the data is saved to disk.
    """

    # url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
    url = 'https://gitee.com/rongqinchen/planetoid/raw/master/data'
    # If the GitHub link is unavailable, use the Gitee mirror above.

    def __init__(self, root, transform=None, pre_transform=None):

        super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        # Use the class name here; the original self.name attribute is
        # never defined on this class and would raise an AttributeError.
        return '{}()'.format(self.__class__.__name__)

Instantiating the dataset and printing a few basic statistics:
dataset = PlanetoidPubMed('node_classify/cora')
print(dataset.num_classes)
print(dataset[0].num_nodes)
print(dataset[0].num_edges)
print(dataset[0].num_features)

Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.x
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.y
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://gitee.com/rongqinchen/planetoid/raw/master/data/ind.pubmed.test.index
7
2708
10556
1433

Note: these are Cora's statistics (7 classes, 2708 nodes, 1433 features); PubMed has 3 classes, 19717 nodes, and 500 features. The likely cause is that the root folder 'node_classify/cora' already contained a processed data.pt from an earlier Cora run: the PubMed raw files were downloaded because they were missing from raw_dir, but since processed_file_names ('data.pt') already existed, process() was skipped and the stale Cora file was loaded. Using a fresh root such as 'node_classify/pubmed' produces the actual PubMed statistics.

Node Prediction and Edge Prediction Tasks in Practice

The code for the edge prediction task is adapted from PyG's link_pred.py example.

import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges

dataset = 'Cora'

#path = osp.join(osp.dirname(osp.realpath('__file__')), '..', 'data', dataset)


path = "node_classify/cora"
print(path)
node_classify/cora
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
print(data)
Data(test_neg_edge_index=[2, 527], test_pos_edge_index=[2, 527], train_neg_adj_mask=[2708, 2708], train_pos_edge_index=[2, 8976], val_neg_edge_index=[2, 263], val_pos_edge_index=[2, 263], x=[2708, 1433])
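train_test_split_edges (with its default val_ratio=0.05 and test_ratio=0.1) explains the shapes above: Cora's 10556 directed edges correspond to 5278 undirected edges, of which 263 (5%) become validation and 527 (10%) become test positives; the remaining 4488 undirected edges are stored in both directions as the 8976-column train_pos_edge_index. In recent PyG releases this function is deprecated in favor of RandomLinkSplit; a rough equivalent, assuming PyG >= 2.0:

import torch_geometric.transforms as T

# Sketch only: RandomLinkSplit returns three Data objects instead of
# attaching *_pos_edge_index / *_neg_edge_index attributes to one object.
split = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True)
train_data, val_data, test_data = split(data)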
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        # Two-layer GCN encoder: map node features to node embeddings z.
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        # Dot-product decoder: the logit for edge (i, j) is <z_i, z_j>.
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        # Score every node pair at once; a logit > 0 means sigmoid > 0.5.
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
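To make the decoder concrete, here is a tiny hand-worked sketch (the embedding values are made up for illustration):

import torch

# Two 2-dimensional node embeddings and one candidate edge (0, 1):
z = torch.tensor([[1.0, 2.0],
                  [3.0, -1.0]])
# <z_0, z_1> = 1*3 + 2*(-1) = 1.0, so sigmoid(1.0) ≈ 0.73 edge probability.
logit = (z[0] * z[1]).sum()
print(logit)  # tensor(1.)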


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, 64).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)


def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float, device=device)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels
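get_link_labels simply places a block of ones (one per positive edge) ahead of a block of zeros (one per negative edge), matching the column order used by decode. A tiny illustration with made-up sizes, reusing the function defined above:

# With 3 positive and 2 negative candidate edges the label vector is
# [1., 1., 1., 0., 0.]:
pos = torch.zeros((2, 3), dtype=torch.long)  # dummy edge indices
neg = torch.zeros((2, 2), dtype=torch.long)
print(get_link_labels(pos, neg))  # tensor([1., 1., 1., 0., 0.])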


def train(data):
    model.train()

    # Sample as many negative edges as there are training positives;
    # negatives are re-drawn in every epoch.
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test(data):
    model.eval()

    # Node embeddings are computed from the training graph only; the
    # validation/test edges are never shown to the encoder.
    z = model.encode(data.x, data.train_pos_edge_index)

    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results


best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(data)
    val_auc, tmp_test_auc = test(data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

# After training, score all node pairs and keep those predicted as edges
# (logit > 0, i.e. probability > 0.5).
z = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(z)

Epoch: 001, Loss: 0.6930, Val: 0.6464, Test: 0.6967
Epoch: 002, Loss: 0.6812, Val: 0.6445, Test: 0.6859
Epoch: 003, Loss: 0.7162, Val: 0.6565, Test: 0.6825
Epoch: 004, Loss: 0.6767, Val: 0.6886, Test: 0.6918
Epoch: 005, Loss: 0.6852, Val: 0.7317, Test: 0.7262
Epoch: 006, Loss: 0.6893, Val: 0.7616, Test: 0.7759
Epoch: 007, Loss: 0.6907, Val: 0.7350, Test: 0.7652
Epoch: 008, Loss: 0.6910, Val: 0.6993, Test: 0.7407
Epoch: 009, Loss: 0.6904, Val: 0.6806, Test: 0.7312
Epoch: 010, Loss: 0.6888, Val: 0.6766, Test: 0.7299
Epoch: 011, Loss: 0.6855, Val: 0.6777, Test: 0.7297
Epoch: 012, Loss: 0.6806, Val: 0.6814, Test: 0.7255
Epoch: 013, Loss: 0.6753, Val: 0.6892, Test: 0.7188
Epoch: 014, Loss: 0.6736, Val: 0.7008, Test: 0.7177
Epoch: 015, Loss: 0.6708, Val: 0.7170, Test: 0.7266
Epoch: 016, Loss: 0.6601, Val: 0.7350, Test: 0.7421
Epoch: 017, Loss: 0.6514, Val: 0.7468, Test: 0.7570
Epoch: 018, Loss: 0.6416, Val: 0.7521, Test: 0.7648
Epoch: 019, Loss: 0.6301, Val: 0.7499, Test: 0.7668
Epoch: 020, Loss: 0.6150, Val: 0.7490, Test: 0.7682
Epoch: 021, Loss: 0.6012, Val: 0.7511, Test: 0.7735
Epoch: 022, Loss: 0.5909, Val: 0.7557, Test: 0.7790
Epoch: 023, Loss: 0.5831, Val: 0.7552, Test: 0.7783
Epoch: 024, Loss: 0.5731, Val: 0.7554, Test: 0.7772
Epoch: 025, Loss: 0.5687, Val: 0.7578, Test: 0.7805
Epoch: 026, Loss: 0.5671, Val: 0.7617, Test: 0.7842
Epoch: 027, Loss: 0.5580, Val: 0.7682, Test: 0.7899
Epoch: 028, Loss: 0.5562, Val: 0.7779, Test: 0.7974
Epoch: 029, Loss: 0.5564, Val: 0.7876, Test: 0.8064
Epoch: 030, Loss: 0.5467, Val: 0.7996, Test: 0.8143
Epoch: 031, Loss: 0.5403, Val: 0.8088, Test: 0.8215
Epoch: 032, Loss: 0.5370, Val: 0.8150, Test: 0.8288
Epoch: 033, Loss: 0.5300, Val: 0.8178, Test: 0.8336
Epoch: 034, Loss: 0.5278, Val: 0.8205, Test: 0.8371
Epoch: 035, Loss: 0.5165, Val: 0.8265, Test: 0.8399
Epoch: 036, Loss: 0.5118, Val: 0.8286, Test: 0.8411
Epoch: 037, Loss: 0.5153, Val: 0.8275, Test: 0.8423
Epoch: 038, Loss: 0.5104, Val: 0.8291, Test: 0.8439
Epoch: 039, Loss: 0.5086, Val: 0.8361, Test: 0.8469
Epoch: 040, Loss: 0.5067, Val: 0.8400, Test: 0.8483
Epoch: 041, Loss: 0.5061, Val: 0.8440, Test: 0.8512
Epoch: 042, Loss: 0.5031, Val: 0.8457, Test: 0.8539
Epoch: 043, Loss: 0.4998, Val: 0.8509, Test: 0.8576
Epoch: 044, Loss: 0.4957, Val: 0.8575, Test: 0.8614
Epoch: 045, Loss: 0.4963, Val: 0.8608, Test: 0.8656
Epoch: 046, Loss: 0.4876, Val: 0.8648, Test: 0.8697
Epoch: 047, Loss: 0.4898, Val: 0.8668, Test: 0.8735
Epoch: 048, Loss: 0.4854, Val: 0.8704, Test: 0.8769
Epoch: 049, Loss: 0.4823, Val: 0.8769, Test: 0.8811
Epoch: 050, Loss: 0.4793, Val: 0.8814, Test: 0.8849
Epoch: 051, Loss: 0.4767, Val: 0.8833, Test: 0.8887
Epoch: 052, Loss: 0.4757, Val: 0.8857, Test: 0.8920
Epoch: 053, Loss: 0.4800, Val: 0.8888, Test: 0.8947
Epoch: 054, Loss: 0.4757, Val: 0.8919, Test: 0.8963
Epoch: 055, Loss: 0.4751, Val: 0.8907, Test: 0.8976
Epoch: 056, Loss: 0.4716, Val: 0.8908, Test: 0.8989
Epoch: 057, Loss: 0.4654, Val: 0.8945, Test: 0.9006
Epoch: 058, Loss: 0.4662, Val: 0.8946, Test: 0.9023
Epoch: 059, Loss: 0.4675, Val: 0.8938, Test: 0.9037
Epoch: 060, Loss: 0.4652, Val: 0.8938, Test: 0.9047
Epoch: 061, Loss: 0.4648, Val: 0.8957, Test: 0.9059
Epoch: 062, Loss: 0.4587, Val: 0.8977, Test: 0.9070
Epoch: 063, Loss: 0.4598, Val: 0.8976, Test: 0.9074
Epoch: 064, Loss: 0.4546, Val: 0.8942, Test: 0.9078
Epoch: 065, Loss: 0.4597, Val: 0.8924, Test: 0.9074
Epoch: 066, Loss: 0.4622, Val: 0.8936, Test: 0.9079
Epoch: 067, Loss: 0.4590, Val: 0.8943, Test: 0.9079
Epoch: 068, Loss: 0.4590, Val: 0.8952, Test: 0.9080
Epoch: 069, Loss: 0.4535, Val: 0.8950, Test: 0.9091
Epoch: 070, Loss: 0.4518, Val: 0.8934, Test: 0.9103
Epoch: 071, Loss: 0.4587, Val: 0.8931, Test: 0.9112
Epoch: 072, Loss: 0.4522, Val: 0.8966, Test: 0.9121
Epoch: 073, Loss: 0.4519, Val: 0.8992, Test: 0.9128
Epoch: 074, Loss: 0.4480, Val: 0.9014, Test: 0.9129
Epoch: 075, Loss: 0.4627, Val: 0.9014, Test: 0.9136
Epoch: 076, Loss: 0.4448, Val: 0.8991, Test: 0.9137
Epoch: 077, Loss: 0.4471, Val: 0.8986, Test: 0.9139
Epoch: 078, Loss: 0.4506, Val: 0.8999, Test: 0.9143
Epoch: 079, Loss: 0.4497, Val: 0.9028, Test: 0.9142
Epoch: 080, Loss: 0.4535, Val: 0.9028, Test: 0.9134
Epoch: 081, Loss: 0.4496, Val: 0.9008, Test: 0.9138
Epoch: 082, Loss: 0.4485, Val: 0.8978, Test: 0.9137
Epoch: 083, Loss: 0.4457, Val: 0.8996, Test: 0.9138
Epoch: 084, Loss: 0.4448, Val: 0.9022, Test: 0.9136
Epoch: 085, Loss: 0.4466, Val: 0.9024, Test: 0.9125
Epoch: 086, Loss: 0.4469, Val: 0.9008, Test: 0.9131
Epoch: 087, Loss: 0.4459, Val: 0.8989, Test: 0.9139
Epoch: 088, Loss: 0.4520, Val: 0.8993, Test: 0.9142
Epoch: 089, Loss: 0.4444, Val: 0.9015, Test: 0.9136
Epoch: 090, Loss: 0.4510, Val: 0.9010, Test: 0.9131
Epoch: 091, Loss: 0.4388, Val: 0.9005, Test: 0.9135
Epoch: 092, Loss: 0.4381, Val: 0.9007, Test: 0.9143
Epoch: 093, Loss: 0.4459, Val: 0.9016, Test: 0.9145
Epoch: 094, Loss: 0.4460, Val: 0.9028, Test: 0.9140
Epoch: 095, Loss: 0.4480, Val: 0.9015, Test: 0.9129
Epoch: 096, Loss: 0.4441, Val: 0.9028, Test: 0.9140
Epoch: 097, Loss: 0.4448, Val: 0.9034, Test: 0.9149
Epoch: 098, Loss: 0.4404, Val: 0.9040, Test: 0.9146
Epoch: 099, Loss: 0.4456, Val: 0.9020, Test: 0.9129
Epoch: 100, Loss: 0.4426, Val: 0.9023, Test: 0.9134
