来源:ICML 2017
论文链接: https://arxiv.org/abs/1704.01212
代码: https://github.com/ifding/graph-neural-networks
这篇论文本身没有提出什么新东西,基于以往的研究,提出了一个消息传递框架MPNN(Message Passing Neural Networks),然后基于这个框架选择了一个较为高效的变体,应用到了chemical prediction中
1 为什么使用神经网络进行化学预测
- 其实现在已经有方法可以预测化学分子的性质了:量子力学模拟方法(DFT)
- 然而这个方法是计算昂贵的
- 因此选择神经网络进行处理
2 MPNNs
MPNN框架的大致内容
- 在节点更新方面,消息传递框架大概可分为两个部分
- 聚合周围节点的信息
- 根据聚合到的信息得到该轮迭代中,自己的表征
- 在图级表征方面,当节点更新迭代完成之后,会有一个图级的读出函数,聚合所有节点的信息,得到图级表征
-
- M,U分别表示消息传递和节点更新的函数(神经网络)
- (1)式表示消息传递步骤,聚合一个点自身、邻居、边进行消息聚合
- (2)式表示节点更新,使用上一次迭代的特征h𝑣𝑇和本次迭代聚合到的信息𝑚𝑣𝑡+1得到h𝑣𝑡+1
-
- 得到图级表示的读出操作
- 需要保证排列不变性
六种具体的变体
1 Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)
-
- f 是一个神经网络,𝑊𝑡 为可学习参数
- 存在问题:最后的消息是节点和边分别求和的
- 该模型实现的消息传递无法识别节点和边之间的相关性
2 Gated Graph Neural Networks (GG-NN), Li et al. (2016)
-
- 可以看成是一种软注意力机制,前面的一项用来计算注意力系数
3 Interaction Networks, Battaglia et al. (2016)
- 这篇论文考虑图中的节点和图结构,同时也考虑每个时间步下的节点级的影响
- 𝑀(h𝑣,h𝑤,𝑒𝑣𝑤)输入: (h𝑣,h𝑤,𝑒𝑣𝑤)为的神经网络
- 更新函数的输入:(h𝑣,𝑥𝑣,𝑚𝑣),其中 𝑥𝑣是一个外部向量,表示对顶点v的一些外部影响
- 图级别的输出: 𝑅=𝑓(𝑣∈𝐺h𝑣𝑇),其中 𝑓是一个神经网络,输入是最终的隐藏层状态的和,𝑇 = 1
4 Molecular Graph Convolutions, Kearnes et al. (2016)
5 Deep Tensor Neural Networks, Schutt et al. (2017)
-
- NN为单层神经网络
6 Laplacian Based Methods, Bruna et al. (2013); Defferrard et al. (2016); Kipf & Welling (2016)
- 即谱图卷积神经网络
3 论文提出来的MPNN的变体
消息传递
- 首先从GGNN开始:
- 采用GGNN的方法,但是把边再参数化了
- 如果节点消息同时依赖于源节点 w 和目标节点 v 的话,网络的消息通道将会得到更有效的利用:
虚拟节点
- 目的:更改消息在整个模型中的传递方式(允许信息在传播阶段长距离传播)。两个方法:
- 为未连接的成对节点添加单独的“虚拟”边类型。这可以作为数据预处理步骤来实现,
- 使用潜在的“主”节点(master node),通过特殊的边来连接到图中任意一个节点。主节点充当了一个全局的暂存空间,每个节点都会在消息传递过程中通过主节点进行读取和写入。同时允许主节点具有自己的节点维度,以及内部更新函数(GRU)的单独权重。其目的同样是为了在传播阶段传播很长的距离
读出函数
- 两种
- GGNN中的读出函数:
- 另外一种,Set2Set模型实现了排列不变性,可以作为一种更好的选择
- GGNN中的读出函数:
Multiple Towers
- 实在不知道怎么翻译,大概意思是,原本的消息聚合需要𝑂𝑛2𝑑2的复杂度,很大,于是将嵌入分为𝑘份,每份单独进行消息聚合,然后最后将其拼接到一起,过一个神经网络,进行变换得到该层的表示:
- 从这个角度看很像多头注意力机制
- 计算复杂度简化为了:On2dk2
4 实验及结果
输入
- 对于分子来说有很多可以提取的特征,比如说原子组成、化学键等:
- 分子图中边的三种表示形式
- Chemical Graph:不考虑距离,邻接矩阵的值是离散的键类型:单键,双键,三键或芳香键
- Distance bins:基于矩阵乘法的消息函数的前提假设是边信息是离散的,将键的距离分为 10 个 bin,比如说 [2,6]中均匀划分 8 个 bin, [0,2]为 1 个 bin,[6,+∞] 为 1 个 bin
- Raw distance feature:同时考虑距离和化学键的特征,邻接矩阵的每个实例都是一个 5 维向量,第一维是距离,其余四维是四种不同的化学键
结果
5 感想
- 文章主要的贡献是将以往的研究抽象出了一个框架,让我们更加清晰地看待问题。可能也为GIN的那篇文章系统分析MPNNs的上限打下了基础。看起来这篇文章没有什么创新点,只是几个框架的组合,但是抽象出一个框架本身就已经是很了不起的一件事了。