【学习小记】数据完整存于内存的数据集类+节点预测与边预测任务实践

本次教程依旧来自DataWhale。GNN组对学习
在此之前先补充一个error:
CUDA error: the provided PTX was compiled with an unsupported toolchain
找了好久,差点重启服务器了。其实是装包的时候,匹配的CUDA版本不对。。。

InMemoryDataset

自定义一个数据可全部存储到内存的数据集类,简而言之自定义数据集时候用。直接看怎么构造一个数据集。这里用的PubMed来改,虽然Planetoid类包含了,但这里做些许修改:

class PlanetoidPubMed(InMemoryDataset):
    r""" 节点代表文章,边代表引用关系。
   		 训练、验证和测试的划分通过二进制掩码给出。
    参数:
        root (string): 存储数据集的文件夹的路径
        transform (callable, optional): 数据转换函数,每一次获取数据时被调用。
        pre_transform (callable, optional): 数据转换函数,数据保存到文件前被调用。
    """

    url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
    # url = 'https://gitee.com/rongqinchen/planetoid/raw/master/data'
    # 如果github的链接不可用,请使用gitee的链接

    def __init__(self, root, transform=None, pre_transform=None):

        super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property#防止属性被修改
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)

感觉这个学习方法还是蛮硬核的哈!

运行流程如下:首先,检查数据原始文件是否已下载;其次,检查数据是否经过处理;接着,检查是否存在处理好的数据。

边预测任务实践

目标:两节点之间是否存在边。

注意:edge_index里存储的是正样本,为了构建任务,需生成一些负样本,即采样一些不存在边的节点对作为负样本边。且数量应该平衡。

PyG中为我们提供了现成的采样负样本边的方法,train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1)。该函数将自动地采样得到负样本,并将正负样本分成训练集、验证集和测试集三个集合。

构造神经网络:

import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

用于做边预测的神经网络主要由两部分组成:其一是编码(encode),它与我们前面介绍的节点表征生成是一样的;其二是解码(decode),它根据边两端节点的表征生成边为真的几率(odds)。decode_all(self, z)用于推理(inference)阶段,我们要对所有的节点对预测存在边的几率。

上面其实没太明白。。。我会回来研究的!!!(继续立flag

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值