相关资源
论文:https://papers.nips.cc/paper/2020/file/94aef38441efa3380a3bed3faf1f9d5d-Paper.pdf
代码:https://github.com/tencent-ailab/grover
这篇文章发表于NIPS 2020,提出了一种借鉴Transformer的图神经网络,且涉及了两个自监督任务来预训练模型。整个模型称为GROVER。
Motivation
两个问题阻碍了GNNs在实际场景中的使用:
- 带标签的分子数较少,远远不够用于监督学习;
- 没有很好的泛化性,不能泛化到其他新合成分子的学习上。
Model
Model Architecture
论文将GNN Transformer简称为GTransformer,包括node GTransformer和edge GTransformer,由于两部分结构相似,这里只介绍node的。
GTransformer的策略很简单,主要分为两个level的信息提取:
- 消息传递过程能够捕捉局部的结构信息,因此使用GNN的输出作为query、key和value,可以得到局部的子图结构的信息;
- 另外,Transformer编码器可以看成GAT在全连接图上的变体,因此作者认为在所有query、keys和value之上,再进行一次Transformer编码,可以得到全局的图结构信息。
另外,模型使用一个long-range residual connection,直译就是长期残差连接,从输入层直接被送到GTransformer的最后一层,该连接传递了初始的节点/边信息,在其他网络中一般会用多个短期的残差连接。本文的长期残差连接的好处是:1) 能够减轻梯度消失问题;2) 相比多个短期残差连接,本文的长期残差能够缓解消息传递过程中的过平滑问题(over-smoothing)。
总体模型图如下:
到这里还没有介绍文中用到的消息传递机制,文中改进了GNN,用dyMPN描述,具体地,对于第
l
l
l次迭代,第
k
k
k个hop可以写成:
m
v
(
l
,
k
)
=
AGGREGATE
(
l
)
(
(
h
v
(
l
,
k
−
1
)
,
h
u
(
l
,
k
−
1
)
,
e
u
v
)
∣
u
∈
N
v
)
m_v^{(l,k)}=\text{AGGREGATE}^{(l)}({(h_v^{(l,k-1)},h_u^{(l,k-1)},e_{uv})|u\in N_v})
mv(l,k)=AGGREGATE(l)((hv(l,k−1),hu(l,k−1),euv)∣u∈Nv)
h
v
(
l
,
k
)
=
σ
(
W
(
l
)
m
v
(
l
,
k
)
+
b
(
l
)
)
h_v^{(l,k)}=\sigma(W^{(l)}m_v^{(l,k)}+b^{(l)})
hv(l,k)=σ(W(l)mv(l,k)+b(l))
假设一共有
L
L
L层,第
l
l
l层有
K
l
K_l
Kl次hops。其中
m
v
(
l
,
k
)
m_v^{(l,k)}
mv(l,k)是聚合的消息,聚合了节点邻域以及与邻域节点相连的边的关系,且
l
l
l层的节点
v
v
v初始状态
h
v
(
l
,
0
)
:
=
h
v
(
l
−
1
,
K
l
−
1
)
h_v^{(l,0)}:=h_v^{(l-1,K_{l-1})}
hv(l,0):=hv(l−1,Kl−1)。作者认为number of hops可以理解成几阶邻域的概念,与图卷积的感受野很相似,这个参数将直接影响消息传递模型的泛化能力。
为了能够兼容不同的数据集, K l K_l Kl的取值不是通过预训练得到,而是通过一个随机策略。两个策略选择:1) K l ∼ U ( a , b ) K_l\sim U(a,b) Kl∼U(a,b),也就是符合均匀分布;2) K l ∼ ϕ ( μ , σ , a , b ) K_l\sim \phi (\mu,\sigma,a,b) Kl∼ϕ(μ,σ,a,b),也就是符合截断标准正态分布,关于截断标准正态,可以看一下知乎这个问答。
因为上述的随机消息传递机制保证了每个节点随机的感受野,文章叫这个模型为Dynamic Message Passing Networks。
Self-supervised Task
作者认为,一个好的自监督任务需要具备两个条件:1)预测目标是可靠的,且容易获得;2)预测目标需要反应节点/边的内容信息。
- Contextual Property Prediction:
这里论文定义了一个统计属性,该属性的计算如下:1)给定一个目标节点,提取它的局部k-hop的邻居节点和边。2)提取该局部图的统计属性,也就是node-edge-counts,比如下图,目标节点是C(碳原子),如果k=1,也就是一阶邻域,那么将会采样到N和O,最后得到该局部图的该统计属性为C_N-DOUBLE1_O-SINGLE1。
给定一个分子图,将其输入GROVER的编码器后,我们获取了它的每个节点和边的嵌入表示,假设节点 v v v的嵌入表示为 h v h_v hv,可以用这个嵌入输入一个前馈网络,来得到节点 v v v的Contextual Property的预测值。 - Graph-level Motif Prediction:
Motifs是一个recurrent子图,在分子中,一种重要的motifs类别就是官能团,可以很容易通过专业软件检测到,它编码了丰富的分子领域信息。motif预测问题可以看成一个多分类问题。