【图神经网络】学习聚合函数 GraphSAGE

本文为图神经网络学习笔记,讲解学习聚合函数 GraphSAGE。欢迎在评论区与我交流 👏

前言

本教程在 PPI(蛋白质网络)数据集上用 Tensorflow 搭建 GraphSAGE 框架中的 MaxPooling 聚合模型实现有监督下的图节点标签预测任务。

GraphSAGE 简介

GraphSAGE 是一种在超大规模图上,利用节点的属性信息高效产生未知节点特征表示归纳式学习框架。GraphSAGE 可以被用来生成节点的低维向量表示,尤其对于具有丰富节点属性的 Graph 效果显著。

目前大多数的框架都是直推式学习模型,即只能够在一张固定的 Graph 上进行表示学习,这样既不能够对那些在训练中未见的节点进行有效的向量表示,也不能够跨图进行节点表示学习。GraphSAGE 作为一种归纳式的表示学习框架,能够利用节点丰富的属性信息有效地生成未知节点的特征表示。

在这里插入图片描述

GraphSAGE的核心思想是通过学习一个对邻居节点进行聚合表示的函数,来产生中心节点的特征表示,而不是学习节点本身的 embedding。它既可以进行监督学习也可以进行无监督学习,GraphSAGE 中的聚合函数有以下几种:

  • Mean Aggregator

    Mean 聚合近似等价 GCN 中的卷积传播操作。具体来说就是对中心节点的邻居节点的特征向量求均值,然后和中心节点特征向量拼接,中间有两次非线性变换。

  • GCN Aggregator

    GCN的归纳式学习版本

  • Pooling Aggregator

    在这里插入图片描述

    先对中心节点的邻居节点表示向量进行一次非线性变换,然后对变换后的邻居表示向量进行池化操作(mean pooling 或者 max pooling),最后将 pooling 所得结果与中心节点的特征表示分别进行非线性变换,并将所得结果进行拼接或者相加从而得到中心节点在该层的向量表示。

  • LSTM Aggregator

    将中心节点的邻居节点随机打乱作为输入序列,将所得向量表示与中心节点的向量表示分别经过非线性变换后拼接,得到中心节点在该层的向量表示。LSTM 本身用于序列数据,而邻居节点没有明显的序列关系,因此输入到 LSTM 中的邻居节点需要随机打乱顺序

以 MaxPooling 聚合方法为例构建 GraphSAGE 模型进行有监督学习下的分类任务。

PPI 数据集

PPI(Protein-protein interaction networks)数据集由 24 个对应人体不同组织的图组成。其中 20 个图用于训练,2 个图用于验证,2 个图用于测试。平均每张图有 2372 个节点,每个节点有 50 个特征。测试集中的图与训练集中的图没有交叉,即在训练阶段测试集中的图是不可见的。每个节点拥有多种标签,标签的种类总共有 121 种。

构建模型

我们使用的核心库是 tf_geometric,借助这个 GNN 库可以方便地导入数据集,预处理图数据以及搭建图神经网络。另外我们还引用了 tf.keras.layers 中的 Dropout 缓解过拟合,以及 sklearn 中的 micro f1_score 函数作为评价指标。

导入库函数:

# coding=utf-8
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tf_geometric.layers.conv.graph_sage import  MaxPoolingGraphSage
from tf_geometric.datasets.ppi import PPIDataset
from sklearn.metrics import f1_score
from tqdm import tqdm
from tf_geometric.utils.graph_utils import RandomNeighborSampler 

加载数据集,使用 tf_geometric自带的PPI数据集。 tf_geometric 提供了简单的图数据构建接口,只需要传入简单的 Python 数组或 Numpy 数组作为节点特征和邻居表就可以构建自己的数据集,如 GIN

# 使用 tf_geometric 自带的 PPI 数据集,返回划分好的训练集(20),验证集(2),测试集(2)。
train_graphs, valid_graphs, test_graphs = PPIDataset().load_data()

由于每个节点的邻居节点的数目不一,出于计算效率的考虑,我们对每个节点采样一定数量的邻居节点作为之后聚合领域信息时的邻居节点。设定采样数量为 num_sample,如果邻居节点的数量大于 num_sample,采用无放回采样。如果邻居节点的数量小于 num_sample,采用有放回采样,直到所采样的邻居节点数量达到 num_sample

# traverse all graphs
for graph in train_graphs + valid_graphs + test_graphs:
  	# andomNeighborSampler 提前对每张图进行预处理,将相关的图信息与各自的图绑定
    neighbor_sampler = RandomNeighborSampler(graph.edge_index)
    # 模型可能会同时作用在多个图上,要保证每张图的邻居节点在抽样结束后不混淆
    # 将抽样结果与每个 Graph 对象绑定。即将抽样信息保存在“cache"缓存字典中
    graph
  • 15
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值