《Graph U-Nets》阅读笔记
前言
U-Net 永远的神!!!!!!!碰巧上学期期末的时候,用U-Net做了一个肿瘤分割的作业,对U-Net了解了一些,所以这篇读
起来并不是那么费劲。本文把U-Net的思想引入到了Graph里,设计了gPool和gUnpool结构,分别对应U-Net中的下采样和上
采样操作。
论文链接:http://proceedings.mlr.press/v97/gao19a/gao19a.pdf
代码链接:https://github.com/HongyangGao/Graph-U-Nets
一、U-Net
先来瞅瞅U-Net长什么样:
U-Net就不做解释了,放一个论文链接:https://arxiv.org/pdf/1505.04597.pdf
二、Graph U-Net
Graph U-Nets长这个样子,和U-Net简直一毛一样啊
这里的GCN不就是U-Net里的conv吗? gPool不就是max pool吗? gUnpool不就是up-conv吗? 橙色箭头不就是copy and crop吗?
可Graph是非欧数据,不能直接把U-Net的下采样和上采样操作应用到非欧数据上。因此,很有必要设计两个结构来实现Graph的下采样和上采样操作。
1.gPool
目标:实现Graph的下采样操作,保留少些比较能代表原Graph的节点。
输入:Gl = (Al, Xl),分别为第l层的邻接矩阵和特征矩阵,Al ∈ Rn×n,Xl ∈ Rn×C。
输出:Gl+1 = (Al+1, Xl+1),分别为第l+1层的邻接矩阵和特征矩阵,Al+1 ∈ Rk×k,Xl+1 ∈ Rk×C, k < n。
步骤如下:
-
设置一个可训练的投影向量p用于计算信息量 。
-
根据节点的特征,计算每个节点能够保留的信息量y,
在p方向上投影的值越大,节点能够保留的信息越多。 -
对投影的值进行排序,选出前k个最大值,并记下它们在原数组中的下标idx
-
构建邻接矩阵:
将这k个节点原来的邻居关系抽取出来组成下一层的邻接矩阵
注意:作者考虑到池化后的graph可能存在孤立节点的情况,从而影响了信息的传递。所以作者又重新设计了池化时的邻接矩阵,先将节点与其h跳之内的邻居都连接起来以增加连接性,再进行池化。本文中h=2,修改后的邻接矩阵为
-
构建特征矩阵:
先从原特征矩阵中抽取这k个节点对应的特征向量
随后通过一个门操作得到最终的特征矩阵
y(idx)代表从y中提取下标为idx的值;⊙则是哈达玛积。注:原文中说这个门操作可以控制信息流,并且使得投影向量p可以通过BP算法训练。如果没有门操作的话,计算节点的信息量时会造成数据离散的情况。这里不是特别能理解。
2.gUnpool
目标:实现Graph的上采样操作,恢复出上一层的结构。
输入:Gl = (Al, Xl),Al ∈ Rk×k,Xl ∈ Rk×C。
输出:Gl+1 = (Al+1, Xl+1),Al+1 ∈ Rn×n,Xl+1 ∈ Rn×C, n > k。
步骤如下:
- 创建一个全0矩阵,其维度为Rn×C。
- 根据此前池化时记录的下标idx,把节点的特征向量,分配到全0矩阵对应的行上去以构成上采样后的特征矩阵
例如,某次池化过程所记录的idx = [1,2,4,7],则依次把 Xl+1 ∈ R4×C 的行向量填充到全0矩阵的第1,2,4,7行中。 - 邻接矩阵 Al+1 与先前做池化时的邻接矩阵相同,参考下图左上方和右上方的两个graph。
3.整体框架
- 先对原graph做一次GCN,可用作特征降维。
- 做两次(gPool + GCN) 用于减少节点数,并且提取高阶特征。
- 做两次(gUnpool + GCN) 用于还原出最初的拓扑结构。在上采样中,每次做GCN之前均有一个skip connection的操作,将同stage的节点特征(左侧)与当前的节点特征(右侧)做连接(可以是加和、拼接等操作)。
注意:作者稍微改动了一下GCN的信息传递函数,作者希望在信息传递时,自身节点的特征所占的比重更大。
三、实验结果
1.数据集
用于节点分类的数据集(直推式)
用于图分类的数据集(归纳式)
2.性能测试
节点分类的对比实验
其实相比于GAT,这个模型的性能并不是说提高特别多。作者认为如果用更好一点的图卷积层,比如GAT,作为g-U-Nets的信息传递层的话,效果会更好。
图分类的对比试验
这里的DiffPool-DET在COLLAB数据集上的表现比本文提出的模型要好。另外,作者认为DiffPool在训练的时候用链路预测作为辅助任务以稳定模型的性能,这就说明DiffPool它不稳定…(学到了一个写作技巧)。
3.消融实验
这里作者去掉了gPool和gUnpool结构,只保留了skip connection操作,以研究这两个结构对性能的提升程度。
性能提升的理由:加入gPool和gUnpool后,模型能分析更高阶的特征,从而提高了泛化能力和性能。
4.连通性对性能的影响
前面说作者重新设计了池化后的邻接矩阵,将节点与其h跳之内的邻居都连接起来以增加连接性。这里就研究了一下这个操作对性能的影响有多大。
这肯定好的啦,不好都不会写上来的。边增加了,有利于节点之间信息的传递。
5.模型深度对性能的影响
GNN模型有一个通病就是模型的层数不能叠太多。一般叠2-4层的时候,性能会达到最大值,随后就下降了。层数多了,很容易造成过平滑和过拟合的现象。这篇文章所提的模型也不例外。
之前看到过一篇文章,它研究了数据集中的节点在几跳之后能遍历到整张图,我看它给出的结果一般都是6-7跳。所以我觉得这个模型深度能到多少,其实和数据集也挺有关系的。
6.参数量对性能的影响
参数量越大越容易过拟合咯。
参数量只增加了0.12%,但是性能提升了2.3%。