DW_图深度学习_Task_4

DW_图深度学习_Task_4
学习内容:

  • 数据完全存于内存的数据集类
  • 节点预测与边预测任务实践

学习地址: github/datawhalechina/team-learning-nlp/GNN/

本次学习包括两个部分,构造数据完全存于内存的数据集类以及节点和边的预测任务实现。

数据完全存于内存的数据集类

PyG使用数据的一般过程

  1. 从网络上下载数据原始文件;
  2. 对数据原始文件做处理,为每一个图样本生成一个Data对象
  3. 对每一个Data对象执行数据处理,使其转换成新的Data对象;
  4. 过滤Data对象;
  5. 保存Data对象到文件;
  6. 获取Data对象,在每一次获取Data对象时,都先对Data对象做数据变换(于是获取到的是数据变换后的Data对象)

InMemoryDataset基类

在这里插入图片描述

节点预测与边预测任务实践

节点预测代码结构

  1. 导入SDK

    import os.path as osp
    
    import torch
    import torch.nn.functional as F
    from torch_geometric.data import (InMemoryDataset, download_url)
    from torch_geometric.nn import GATConv, Sequential
    from torch_geometric.transforms import NormalizeFeatures
    from torch_geometric.io import read_planetoid_data
    from torch.nn import Linear, ReLU
    
  2. 定义PlanetoidPubMed数据集类

    class PlanetoidPubMed(InMemoryDataset):
    
        url = 'https://github.com/kimiyoung/planetoid/raw/master/data'	
        def __init__(self, root, split="public", num_train_per_class=20,
                     num_val=500, num_test=1000, transform=None,
                     pre_transform=None):
    
            super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
            self.data, self.slices = torch.load(self.processed_paths[0])
    
            self.split = split
            assert self.split in ['public', 'full', 'random']
    
            if split == 'full':
                data = self.get(0)
                data.train_mask.fill_(True)
                data.train_mask[data.val_mask | data.test_mask] = False
                self.data, self.slices = self.collate([data])
    
            elif split == 'random':
                data = self.get(0)
                data.train_mask.fill_(False)
                for c in range(self.num_classes):
                    idx = (data.y == c).nonzero(as_tuple=False).view(-1)
                    idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
                    data.train_mask[idx] = True
    
                remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
                remaining = remaining[torch.randperm(remaining.size(0))]
    
                data.val_mask.fill_(False)
                data.val_mask[remaining[:num_val]] = True
    
                data.test_mask.fill_(False)
                data.test_mask[remaining[num_val:num_val + num_test]] = True
    
                self.data, self.slices = self.collate([data])
    
        @property
        def raw_dir(self):
            return osp.join(self.root, 'raw')
    
        @property
        def processed_dir(self):
            return osp.join(self.root, 'processed')
    
        @property
        def raw_file_names(self):
            names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
            return ['ind.pubmed.{}'.format(name) for name in names]
    
        @property
        def processed_file_names(self):
            return 'data.pt'
    
        def download(self):
            for name in self.raw_file_names:
                download_url('{}/{}'.format(self.url, name), self.raw_dir)
    
        def process(self):
            data = read_planetoid_data(self.raw_dir, 'pubmed')
            data = data if self.pre_transform is None else self.pre_transform(data)
            torch.save(self.collate([data]), self.processed_paths[0])
    
        def __repr__(self):
            return '{}()'.format(self.name)
    
    dataset = PlanetoidPubMed(root='data/PlanetoidPubMed/', transform=NormalizeFeatures())
    print('dataset.num_features:', dataset.num_features)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = dataset[0].to(device)
    
  3. 定义网络和训练测试模块

    def train():
        model.train()
        optimizer.zero_grad()  # Clear gradients.
        out = model(data.x, data.edge_index)  # Perform a single forward pass.
        # Compute the loss solely based on the training nodes.
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        return loss
    
    def test():
        model.eval()
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
        test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
        return test_acc
    
    
    class GAT(torch.nn.Module):
        def __init__(self, num_features, hidden_channels_list, num_classes):
            super(GAT, self).__init__()
            torch.manual_seed(12345)
            hns = [num_features] + hidden_channels_list
            conv_list = []
            for idx in range(len(hidden_channels_list)):
                conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
                conv_list.append(ReLU(inplace=True),)
    
            self.convseq = Sequential('x, edge_index', conv_list)
            self.linear = Linear(hidden_channels_list[-1], num_classes)
    
        def forward(self, x, edge_index):
            x = self.convseq(x, edge_index)
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.linear(x)
            return x
    
  4. 模型训练及测试

    %%time
    model = GAT(num_features=dataset.num_features, hidden_channels_list=[200, 100], num_classes=dataset.num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    
    for epoch in range(1, 201):
        loss = train()
        if epoch % 50 == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    
    test_acc = test()
    print(f'Test Accuracy: {test_acc:.4f}')
    

边预测

边预测任务,目标是预测两个节点之间是否存在边。
边预测任务的网络包括三大部分:
Encode, decode 以及用于推理的decode_all部分。

import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

关于边预测任务中数据集的划分问题,需要进一步查资料深入研究。这部分内容明天要查一下相关的学习资料琢磨清楚。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值