前情回顾
1 图表征学习
图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征。图表征的作用是,基于图表征我们可以做图的预测。论文Learning representations of graph data: a survey中对常见的图表征学习方法进行了简单回顾。
将数据构造为图的形式可以帮助我们以一种系统化的方式研究如何发掘复杂的关系和模式。例如,互联网图展示出了给定网页间高频链接的复杂结构;在自然语言处理领域中,人们有时以树的形式表征文本,理解单词之间的联系,从而推断出句子的意义。
然而,机器学习领域的研究主要关注于向量形式的表征,而真实世界中的数据并不能很轻易地被表征为向量。现实世界场景下复杂图结构的例子包括:生物学网络、计算机网络、传感器网络、社交网络、论文引用网络、电力网络和交通网络。通过使用基于图的表征,我们可以捕获结构化数据的顺序、拓扑、集合和其它关系特性。
对于机器学习来说,神经网络方法和非神经网络方法的主要区别在于学习数据的表征。在机器学习术语中,我们会使用特征一词,而在表征学习术语中,我们关心的是学习数据的最优表征,它有助于下游的机器学习任务。
学习图表征背后的思想是:学习一类映射,这类映射将顶点、子图或整体的图嵌入到低维向量空间中的点上。然后,我们优化这些映射,使他们反映嵌入空间的几何结构,学习到的嵌入可以被用作机器学习任务的向量化输入。
常见的图表征方法有五类,包括:核方法、卷及方法、图神经网络方法、图嵌入方法,以及概率方法。“图表征”指的是通过神经网络计算方法学习到特征,每种学习到的表征都分别对图的拓扑信息进行编码。
基于图同构网络(Graph Isomorphism Network, GIN)的图表征网络是当前最经典的图表征学习网络。
2 图同构网络原理
2.1 图同构网络背景:图同构性测试Weisfeiler-Lehman Test (WL Test)
两个图是同构的,意思是两个图拥有一样的拓扑结构,也就是说,我们可以通过重新标记节点从一个图转换到另外一个图。Weisfeiler-Lehman 图的同构性测试算法,简称WL Test,是一种用于测试两个图是否同构的算法。
WL Test 的一维形式,类似于图神经网络中的邻接节点聚合。WL Test 1)迭代地聚合节点及其邻接节点的标签,然后 2)将聚合的标签散列(hash)成新标签,该过程形式化为下方的公式,
L u h ← hash ( L u h − 1 + ∑ v ∈ N ( U ) L v h − 1 ) L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}_{v}\right) Luh←hash⎝⎛Luh−1+v∈N(U)∑Lvh−1⎠⎞
在上方的公式中, L u h L^{h}_{u} Luh表示节点 u u u的第 h h h次迭代的标签,第 0 0 0次迭代的标签为节点原始标签。
在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。
WL测试不能保证对所有图都有效,特别是对于具有高度对称性的图,如链式图、完全图、环图和星图,它会判断错误。
这一部分将分为图同构判断及图相似性评估两大部分,他们都基于WL子树构成。
论文The weisfeiler-lehman method and graph isomorphism testing中对这一部分进行了详细的讲解。
2.1.1 WL子树(作业)
WL子树是Weisfeiler-Lehman Test (WL Test)中的一个重要内容。
Weisfeiler-Lehman Graph Kernels 方法提出用WL子树核衡量图之间相似性。该方法使用WL Test不同迭代中的节点标签计数作为图的表征向量,它具有与WL Test相同的判别能力。直观地说,在WL Test的第 k k k次迭代中,一个节点的标签代表了以该节点为根的高度为 k k k的子树结构。
如下图所示,与常规树不同,WL中的节点可以重复出现,即边可以无向重复计数。
博客 图表征学习(graph representation learning)中也有直观的说明图
具体到作业案例(如下图)
可以写出:
- 6号节点1-3层的WL子树
- 3号节点1-3层的WL子树
- 5号节点1-3层的WL子树
2.1.2 图同构判断
WL-Test通过逐层查找不同来检查图是否同构。
Weisfeiler-Leman Test 算法通过重复执行以下给节点打标签的过程来实现图是否同构的判断:
- 聚合自身与邻接节点的标签得到一串字符串,自身标签与邻接节点的标签中间用
,
分隔,邻接节点的标签按升序排序。排序的原因在于要保证单射性,即保证输出的结果不因邻接节点的顺序改变而改变。- 标签散列,即标签压缩,将较长的字符串映射到一个简短的标签。
- 给节点重新打上标签。
每重复一次以上的过程,就完成一次节点自身标签与邻接节点标签的聚合。
下图展示了一个图同构判断的案例
值得注意的是,当出现两个图相同节点标签的出现次数不一致时,即可判断两个图不相似。但由于我们重复次数永远有限,故我们无法保证两个图是否同构,但可以确定不同构的图。
2.1.3 图相似性评估
WL Test 算法的一点局限性是,它只能判断两个图的相似性,无法衡量图之间的相似性。要衡量两个图的相似性,我们用WL Subtree Kernel方法。该方法的思想是用WL Test算法得到节点的多层的标签,然后我们可以分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的表征向量的内积,即可作为这两个图的相似性估计,内积越大表示相似性越高。
如图所示
2.2 图同构网络
根据提出图同构网络的论文:How Powerful are Graph Neural Networks? ,能实现判断图同构性的图神经网络需要满足,只在两个节点自身标签一样且它们的邻接节点一样时,图神经网络将这两个节点映射到相同的表征,即映射是单射性的。
可重复集合(Multisets)指的是元素可重复的集合,元素在集合中没有顺序关系。 一个节点的所有邻接节点是一个可重复集合,一个节点可以有重复的邻接节点,邻接节点没有顺序关系。 因此GIN模型中生成节点表征的方法遵循WL Test算法更新节点标签的过程。
在生成节点的表征后仍需要执行图池化(或称为图读出)操作得到图表征,最简单的图读出操作是做求和。由于每一层的节点表征都可能是重要的,因此在图同构网络中,不同层的节点表征在求和后被拼接,其数学定义如下,
h G = CONCAT ( READOUT ( { h v ( k ) ∣ v ∈ G } ) ∣ k = 0 , 1 , ⋯ , K ) h_{G} = \text{CONCAT}(\text{READOUT}\left(\{h_{v}^{(k)}|v\in G\}\right)|k=0,1,\cdots, K) hG=CONCAT(READOUT({hv(k)∣v∈G})∣k=0,1,⋯,K)
采用拼接而不是相加的原因在于不同层节点的表征属于不同的特征空间。 未做严格的证明,这样得到的图的表示与WL Subtree Kernel得到的图的表征是等价的。
则基于图同构网络的图表征学习主要包含以下两个过程:
- 首先计算得到节点表征;
- 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout),得到图的表征(Graph Representation)。
3 图同构网络的实现(代码解读)
在教程中,用的是自顶向下的方式进行的解读,图表征可分为节点表征,图池化得到图表征,最后基于图表征转换来实现对图的预测。
若按计算顺序进行组成,则包含
- 节点嵌入得到节点表征(聚合节点信息及边信息),包含
- 编码
AtomEncoder
实现第0层节点表征 - 图同构卷积层
GINConv
进行逐层节点表征
- 编码
- 图池化得到图表征,基本方式包括求和、求平均、求最大值、求加权和等。
3.1 节点表征
3.1.1 AtomEncoder
及BondEncoder
:第0层节点表征
节点(原子)和边(化学键)的属性都为离散值,它们属于不同的空间,无法直接将它们融合在一起。通过嵌入(Embedding),我们可以将节点属性和边属性分别映射到一个新的空间,在这个新的空间中,我们就可以对节点和边进行信息融合。
AtomEncoder
对原子属性做节点嵌入,其实现调用了ogb
库的get_atom_feature_dims
。
ogb
也叫Open Graph Benchmarking,是用于图机器学习的基准数据集、数据加载器和评估器的集合,可关注在KDD 2021上的OGB-LSC官网。
接下来,我们通过下方的代码中的
AtomEncoder
类,来分析将节点属性映射到一个新的空间是如何实现的:
full_atom_feature_dims
是一个链表list
,存储了节点属性向量每一维可能取值的数量,即X[i]
可能的取值一共有full_atom_feature_dims[i]
种情况,X
为节点属性;- 节点属性有多少维,那么就需要有多少个嵌入函数,通过调用
torch.nn.Embedding(dim, emb_dim)
可以实例化一个嵌入函数;torch.nn.Embedding(dim, emb_dim)
,第一个参数dim
为被嵌入数据可能取值的数量,第一个参数emb_dim
为要映射到的空间的维度。得到的嵌入函数接受一个大于0
小于dim
的数,输出一个维度为emb_dim
的向量。嵌入函数也包含可训练参数,通过对神经网络的训练,嵌入函数的输出值能够表达不同输入值之间的相似性。- 在
forward()
函数中,我们对不同属性值得到的不同嵌入向量进行了相加操作,实现了将节点的的不同属性融合在一起。
BondEncoder
类与AtomEncoder
类是类似的。
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()
class AtomEncoder(torch.nn.Module):
"""该类用于对原子属性做嵌入。
记`N`为原子属性的维度,则原子属性表示为`[x1, x2, ..., xi, xN]`,其中任意的一维度`xi`都是类别型数据。full_atom_feature_dims[i]存储了原子属性`xi`的类别数量。
该类将任意的原子属性`[x1, x2, ..., xi, xN]`转换为原子的嵌入`x_embedding`(维度为emb_dim)。
"""
def __init__(self, emb_dim):
super(AtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
for i, dim in enumerate(full_atom_feature_dims):
emb = torch.nn.Embedding(dim, emb_dim) # 不同维度的属性用不同的Embedding方法
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)
def forward(self, x):
x_embedding = 0
for i in range(x.shape[1]):
x_embedding += self.atom_embedding_list[i](x[:,i])
return x_embedding
class BondEncoder(torch.nn.Module):
def __init__(self, emb_dim):
super(BondEncoder, self).__init__()
self.bond_embedding_list = torch.nn.ModuleList()
for i, dim in enumerate(full_bond_feature_dims):
emb = torch.nn.Embedding(dim, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.bond_embedding_list.append(emb)
def forward(self, edge_attr):
bond_embedding = 0
for i in range(edge_attr.shape[1]):
bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])
return bond_embedding
3.1.2 GINConv
:图同构卷积层
图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi′=hΘ⎝⎛(1+ϵ)⋅xi+j∈N(i)∑xj⎠⎞
PyG中已经实现了此模块,我们可以通过torch_geometric.nn.GINConv
来使用PyG定义好的图同构卷积层,然而该实现不支持存在边属性的图。在这里我们自己自定义一个支持边属性的GINConv
模块。
由于输入的边属性为类别型,因此我们需要先将类别型边属性转换为边表征。我们定义的GINConv
模块遵循“消息传递、消息聚合、消息更新”这一过程。
在GINConv
中,message()
函数中的x_j + edge_attr
操作执行了节点信息和边信息的融合。
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr = "add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
3.1.3 节点嵌入的实现
节点嵌入模块基于多层GINConv
实现结点嵌入的计算。
我们逐层计算节点表征,从第
1
层开始到第num_layers
层,每一层节点表征的计算都以上一层的节点表征h_list[layer]
、边edge_index
和边的属性edge_attr
为输入。需要注意的是,GINConv
的层数越多,此节点嵌入模块的感受野(receptive field)越大,结点i
的表征最远能捕获到结点i
的距离为num_layers
的邻接节点的信息。
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module"""
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
# List of GNNs
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr
# computing input node embedding
h_list = [self.atom_encoder(x)] # 先将类别型原子属性转化为原子表征
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
3.2 图池化及图表征
首先采用GINNodeEmbedding
模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测。
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding
class GINGraphPooling(nn.Module):
def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
"""GIN Graph Pooling Module
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
Out:
graph representation
"""
super(GINGraphPooling, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
def forward(self, batched_data):
h_node = self.gnn_node(batched_data)
h_graph = self.pool(h_node, batched_data.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output
else:
# At inference time, relu is applied to output to ensure positivity
# 因为预测目标的取值范围就在 (0, 50] 内
return torch.clamp(output, min=0, max=50)
可选的基于结点表征计算得到图表征的方法有
sum
,mean
,max
,attention
和set2set
。PyG中集成的所有的图池化的方法可见于Global Pooling Layers。