介绍
图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征,基于图表征我们可以做图的预测。基于图同构网络(Graph Isomorphism Network, GIN)的图表征网络是当前最经典的图表征学习网络。
基于图同构网络(GIN)的图表征网络的实现
基于图同构网络的图表征学习主要包含以下两个过程:
- 首先计算得到节点表征;
- 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout),得到图的表征(Graph Representation)
接下来逐一介绍:
GINNodeEmbedding Module
此节点嵌入模块基于多层GINConv实现结点嵌入的计算。输入到此节点嵌入模块的节点属性为类别型向量。
- 首先用AtomEncoder对其做嵌入得到第0层节点表征
- 逐层计算节点表征,GINConv层越多,此节点嵌入模块的感受野(receptive field)越大。
图同构卷积层的数学定义如下:
x
i
′
=
h
Θ
(
(
1
+
ϵ
)
⋅
x
i
+
∑
j
∈
N
(
i
)
x
j
)
\mathbf{x}^{\prime}_i = h{\mathbf{_\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum{_{j \in \mathcal{N}(i)}} \mathbf{x}_j \right)
xi′=hΘ((1+ϵ)⋅xi+∑j∈N(i)xj)
增加边属性的GIN代码如下(message()函数中的x_j + edge_attr 操作执行了节点信息和边信息的融合):
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr = "add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
基于图同构网络的图表征模块(GINGraphRepr Module)
此模块首先采用GINNodeEmbedding模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测。
基于结点表征计算得到图表征的方法有: sum(), mean(), max(), attention(), set2set()
理论分析
理论上,图神经网络在区分图结构方面最高能达到与WL Test一样的能力。
Weisfeiler-Lehman Test (WL Test) 图同构性测试:
WL test的一维形式,类似于图神经网络中的领接节点聚合。
- 迭代地聚合节点及领接节点的标签
- 将聚合的标签散列成新标签
L h u ← hash ( L h − 1 u + ∑ v ∈ N ( U ) L h − 1 v ) L^{h}{u} \leftarrow \operatorname{hash}\left(L^{h-1}{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}{v}\right) Lhu←hash⎝⎛Lh−1u+v∈N(U)∑Lh−1v⎠⎞
在上方的公式中, L h u L^{h}{u} Lhu表示节点 u u u的第 h h h次迭代的标签,第 0 0 0次迭代的标签为节点原始标签。在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。
WL测试不能保证对所有图都有效,特别是对于具有高度对称性的图,如链式图、完全图、环图和星图,它会判断错误。
WL Test 算法举例
给定两个图
G
G
G和
G
′
G^{\prime}
G′,每个节点拥有标签。
聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间用,分隔,邻接节点的标签按升序排序。
将较长的字符串映射到一个简短的标签。
给节点重新打标签。
当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。如果上述的步骤重复一定的次数后,没有发现有相同节点标签的出现次数不一致的情况,那么我们无法判断两个图是否同构。
当两个节点的
h
h
h层的标签一样时,表示分别以这两个节点为根节点的WL子树是一致的。
图相似性评估
衡量两个图的相似性,我们用WL Subtree Kernel方法。该方法的思想是用WL Test算法得到节点的多层的标签,然后我们可以分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的表征向量的内积,即可作为这两个图的相似性估计,内积越大表示相似性越高。
作业
WL子树: