图神经网络 | (6) 图分类(SAGPool)实战

近期买了一本图神经网络的入门书,最近几篇博客对书中的一些实战案例进行整理,具体的理论和原理部分可以自行查阅该书,该书购买链接:《深入浅出的图神经网络》

该书配套代码

本节我们通过代码来实现基于自注意力的池化机制(Self-Attention Pooling)。来对图整体进行分类,之前我们是对节点分类,每个节点表示一条数据,学习节点的表示,进而完成分类,本节我们通过自注意力池化机制,得到整个图的表示,进而对全图进行分类(每个图表示一条数据)。

  • 导入必要的包
import os
import urllib
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import scipy.sparse as sp
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
import pickle
import pandas as pd
import torch_scatter #注意:torch_scatter 安装时编译需要用到cuda
from collections import Counter

目录

1. D&D数据

2. 图卷积

3. Self-Attention Pooling

4. ReadOut

5. 图分类模型

6. 程序流程


 

1. D&D数据

D&D是一个包含1178个蛋白质结构的数据集。每个蛋白质被表示为一个图,其中节点是氨基酸,如果两个节点之间的距离小于6埃,那么他们之间会有一条边相连。预测任务是将蛋白质分类为酶和非酶。

  • 数据集下载和预处理
class DDDataset(object):
    #数据集下载链接
    url = "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/DD.zip"
    
    def __init__(self, data_root="data", train_size=0.8):
        self.data_root = data_root
        self.maybe_download() #下载 并解压
        sparse_adjacency, node_labels, graph_indicator, graph_labels = self.read_data()
        #把coo格式转换为csr 进行稀疏矩阵运算
        self.sparse_adjacency = sparse_adjacency.tocsr()
        self.node_labels = node_labels
        self.graph_indicator = graph_indicator
        self.graph_labels = graph_labels
        
        self.train_index, self.test_index = self.split_data(train_size)
        self.train_label = graph_labels[self.train_index] #得到训练集中所有图对应的类别标签
        self.test_label = graph_labels[self.test_index] #得到测试集中所有图对应的类别标签

    def split_data(self, train_size):
        unique_indicator = np.asarray(list(set(self.graph_indicator)))
        #随机划分训练集和测试集 得到各自对应的图索引   (一个图代表一条数据)
        train_index, test_index = train_test_split(unique_indicator,
                                                   train_size=train_size,
                                                   random_state=1234)
        return train_index, test_index
    
    def __getitem__(self, index):
       
        mask = self.graph_indicator == index  
        #得到图索引为index的图对应的所有节点(索引)
        graph_indicator = self.graph_indicator[mask]
        #每个节点对应的特征标签
        node_labels = self.node_labels[mask]
        #该图对应的类别标签
        graph_labels = self.graph_labels[index]
        #该图对应的邻接矩阵
        adjacency = self.sparse_adjacency[mask, :][:, mask]
        return adjacency, node_labels, graph_indicator, graph_labels
    
    def __len__(self):
        return len(self.graph_labels)
    
    def read_data(self):
        #解压后的路径
        data_dir = os.path.join(self.data_root, "DD")
        print("Loading DD_A.txt")
        #从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置)  包含所有图的节点
        adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),
                                       dtype=np.int64, delimiter=',') - 1
        print("Loading DD_node_labels.txt")
        #读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
        node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"), 
                                    dtype=np.int64) - 1
        print("Loading DD_graph_indicator.txt")
        #每个节点属于哪个图
        graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"), 
                                        dtype=np.int64) - 1
        print("Loading DD_graph_labels.txt")
        #每个图的标签 (2分类 0,1)
        graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"), 
                                     dtype=np.int64) - 1
        num_nodes = len(node_labels) #节点数 (包含所有图的节点)
        #通过邻接表生成邻接矩阵  (包含所有的图)稀疏存储节省内存(coo格式 只存储非0值的行索引、列索引和非0值)
        #coo格式无法进行稀疏矩阵运算
        sparse_adjacency = sp.coo_matrix((np.ones(len(adjacency_list)), 
                                          (adjacency_list[:, 0], adjacency_list[:, 1])),
                                         shape=(num_nodes, num_nodes), dtype=np.float32)
        print("Number of nodes: ", num_nodes)
        return sparse_adjacency, node_labels, graph_indicator, graph_labels
    
    def maybe_download(self):
        save_path = os.path.join(self.data_root)
        #本地不存在 则下载
        if not os.path.exists(save_path):
            self.download_data(self.url, save_path)
        #对数据集压缩包进行解压
        if not os.path.exists(os.path.join(self.data_root, "DD")):
            zipfilename = os.path.join(self.data_root, "DD.zip")
            with ZipFile(zipfilename, "r") as zipobj:
                zipobj.extractall(os.path.join(self.data_root))
                print("Extracting data from {}".format(zipfilename))
    
    @staticmethod
    def download_data(url, save_path):
        """数据下载工具,当原始数据不存在时将会进行下载"""
        print("Downloading data from {}".format(url))
        if not os.pat
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值