GraphSAGE和DiffPool操作实录
简介(真的是很简单的介绍)
GraphSAGE主要是对图(Graph)的节点进行特征表示,特征表示出的节点用于下游任务:如节点分类、图表示等;DiffPool则是对输入的节点特征进行聚合,得到一个聚合的特征用来表示图,通常用于输入一堆图之后进行分类的任务。
实验过程记录
实验过程主要是使用了GraphSAGE的无监督训练得到节点的特征,然后使用得到的特征输入至DiffPool的模型中进行图表示,最后使用有监督的方法对图进行分类。具体的信息来自于两篇论文:Inductive Representation Learning on Large Graphs和Hierarchical Graph Representation Learning with Differentiable Pooling
使用GraphSAGE的操作
一、数据准备
GraphSAGE的输入最多有五个文件,其中①③⑤是进行有监督训练(节点分类)需要的,如果只是需要对节点进行特征表示,输出节点的embedding,文件⑤可以没有。
① -G.json格式的文件,其中存储的信息格式为
{““directed”: false,
“graph”: {},
“nodes”:[存储节点的信息,格式为{“test”: true/false, “id”: “节点名字”, “val”:true/ false}],
”links”:[存储边的信息,格式为{“source”: id, “target”: id}] }
对于所有节点,采用十倍交叉验证的方法来选择test或者val的属性
② -feats.npy[optional]格式的文件,存储各个节点的初始输入特征向量,每行数据对应一个节点的特征向量,文章中的三个数据集都有对应的初始特征的生成方式,如果自己的数据没有初始特征,可以使用身份特征(笔者这里使用的是深度较小的node2vec生成的特征向量,其实只要可以对每个节点生成不同的特征即可,比如one-hot,方法不限)
③ -id_map.json文件,存储节点名和其对应的id
④ -walks.txt[optional]文件,存储每个节点的大小为198的随机游走,可以使用代码中的utils.py文件生成,命令为:python -m graphsage.utils --input 输入文件名 --output 输出文件名
⑤ -class_map.json文件,用于存储节点和其属于的类别,用于有监督分类和无监督分类的检验,如果只是需要输出节点的特征表示,这个文件可以没有
二、 代码简述(无监督训练)
源码来自GitHub,引用自https://github.com/williamleif/GraphSAGE
① 聚合器:GraphSAGE的文章中主要提到了三种聚合器,实现写在aggregators.py中,源码中定义了6种不同的聚合器,分别对应不同的聚合方法。
② 无监督分类:三个脚本文件分别用于分类聚合器产生的节点特征,其中主要用的分类器是线性分类器和随机分类器,然后还有和原始特征分类作比较的部分,源码中写的很详细,此处不多叙述。
三、 输出
使用无监督训练的目标,根据文章中的叙述,是为了使相邻的节点特征差异变小,远离的节点差异变大,训练中使得定义的损失函数尽量的小。在训练结束后,程序会输出结果,格式如下:
其中val.npy文件的每一行都代表一个训练得到的节点特征向量,val.txt的每行存储一个节点,与npy文件中的特征向量相对应。
使用DiffPool的操作
一、 数据准备
(数据整理十分头疼,主要是GraphSAGE和DiffPool的数据并没有对接好,所以需要对格式进行相应的整理)
DiffPool需要输入的数据是:
① _A.txt:存储邻接信息,具体来说,就是把很多个图的节点按顺序编号,然后存入节点之间的邻接信息。
② _graph_indicator.txt:第i行对应的值x表示第i个节点属于第x个图
③ _graph_lables.txt:第i行对应的值x表示第i个图属于第x个类
④ _node_attributes.txt:第i行表示第i个节点的特征向量
⑤ _node_labels.txt:表示节点在各个图中的编号(似乎是),由于笔者自己的实验数据不需要用到这部分编号,也没有看懂叙述,所以无法详细讲啦。
在准备自己的训练数据时,_node_labels.txt文件可以没有,如果没有自己准备,程序会自动生成一个相应的编号用于训练,不怎么会影响结果,_node_attributes.txt在DD数据中也没有,程序会对应生成一个常数特征用于训练
其中DD和ENZYMES都是蛋白质数据,分别有1178和600个图,分别分2类和6类,在使用.sh文件进行训练自己的数据时,需要注意参数—max-nodes改为自己数据中最多节点图的节点数
二、 代码简述
(使用的是聚类方法,应该也是无监督,但是由于我们需要验证分类的准确性,所以还是要输入数据的正确分类,即_graph_lables.txt文件)
源码来自https://github.com/RexYing/diffpool
在代码中,作者使用了三种方法来对比图分类的效果,分别是base、set2set、soft-assign
在训练中也是使用十倍交叉验证来对数据集进行划分,对数据训练10次,每次训练1000轮,得到的结果取平均,最后对训练输出一个曲线图,横轴表示训练轮数,纵轴表示分类准确度。线表示训练数据的准确率,点表示测试数据的分类准确度。
在train.py文件中有两个函数,分别可以用来生成1000个两类各500个BA无标度网络和3000个三类各500个ER随机网络来用于训练图分类。
训练需要设备有独显,笔者使用的是1660Ti显卡,需要安装NVIDA的驱动,但是对于cuda还是不能使用,不知道哪里出现了问题,但是将参数中的cuda设为0依然可以训练(这里就很神奇了)。
------最后写文章和之前实验的时间隔了好久,坑填不完了,先搞完前一大半出来吧~