今天学习的是谷歌大脑的同学 2017 年的工作《Neural Message Passing for Quantum Chemistry》,也就是我们经常提到的消息传递网络(Message Passing Neural Network,MPNN),目前引用数超过 900 次。
严格来说,MPNN 不是一个模型,而是一个框架。作者在这篇论文中主要将现有模型抽象其共性并提出成 MPNN 框架,同时利用 MPNN 框架在分子分类预测中取得了一个不错的成绩。
1.Introduction
深度学习被广泛应用于图像、音频、NLP 等领域,但在化学任务(分子分类等)中仍然使用中机器学习+特征工程的方式,其主要原因在于目前尚未有工作证明深度学习在这个领域能取得很大的成功。
近年来,随着量子化学计算和分子动力学模拟等实验的展开产生了巨大的数据量,大多数经典的技术都无法有效利用目前的大数据集。而原子系统的对称性表明,能够应用于网络图中的神经网络也能够应用于分子模型。所以,找到一个更加强大的模型来解决目前的化学任务可以等价于找到一个适用于网络的模型。
在这篇论文中,作者的目标是证明:能够应用于化学预测任务的模型可以直接从分子图中学习到分子的特征,并且不受到图同构的影响。为此,作者将应用于图上的监督学习框架称之为消息传递神经网络(MPNN),这种框架是从目前比较流行的支持图数据的神经网络模型中抽象出来的一些共性,抽象出来的目的在于理解它们之间的关系。
鉴于目前已经有很多类似 MPNN 框架的模型,所以作者呼吁学者们应该将这个方法应用到实际的应用中,并且通过实际的应用来提出模型的改进版本,尽可能的去推广模型的实际应用。
本文给出的一个例子是利用 MPNN 框架代替计算代价昂贵的 DFT 来预测有机分子的量子特性:
2.MPNN
本节内容分为两块,一块是看下作者如何从现有模型中抽象出 MPNN 框架,另一块是看下作者如何利用 MPNN 框架去解决实际问题。
2.1 MPNN framework
我们先来介绍下 MPNN 这一通用框架,并通过八篇文献来举例验证 MPNN 框架的通配性。
简单起见,我们考虑无向图 G,节点 v 的特征为 x v x_v xv,边的特征为 e v w e_{vw} evw。前向传递有两个阶段:一个是消息传递阶段(Message Passing),另一个是读出阶段(Readout)。考虑消息传递阶段,消息函数定义为 M t M_t Mt,顶点更新函数定义为 U t U_t Ut,t 为运行的时间步。在消息传递过程中,隐藏层节点 v 的状态 h v t h_v^t hvt 可以被基于 m v t + 1 m_v^{t+1} mvt+1 进行更新:
m v t + 1 = ∑ w ∈ N ( v ) M t ( h v t , h w t , e v w ) h v t + 1 = U t ( h v t , m v t + 1 ) \begin{aligned} m_v^{t+1} &= \sum_{w\in N(v)}M_t(h_v^t, h_w^t,e_{vw}) \\ h_v^{t+1} &= U_t(h_v^t, m_v^{t+1}) \end{aligned} \\ mvt+1hvt+1=w∈N(v)∑Mt(hvt,hwt,evw)=Ut(hvt,mvt+1)
其中, N ( v ) N(v) N(v) 表示图 G 中节点 v 的邻居。
读出阶段使用一个读出函数 R 来计算整张图的特征向量:
y ^ = R ( h v T ∣ v ∈ G ) \hat y = R({h_v^T | v \in G}) \\ y^=R(hvT∣v∈G)
消息函数 M t M_t Mt,向量更新函数 U t U_t Ut 和读出函数 R R R 都是可微函数。 R R R 作用于节点的状态集合,同时对节点的排列不敏感,这样才能保证 MPNN 对图同构保持不变。
此外,我们也可以通过引入边的隐藏层状态来学习图中的每一条边的特征,并且同样可以用上面的等式进行学习和更新。
接下来我们看下如何通过定义消息函数、更新函数和读出函数来适配不同种模型。
Paper 1 : Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)
这篇论文中消息函数为:
M ( h v , h w , e v w ) = ( h w , e v w ) M(h_v, h_w,e_{vw}) = (h_w,e_{vw}) \\ M(hv,hw,evw)=(hw,evw)
其中 ( . , . ) (.,.) (.,.) 表示拼接(concat);
节点的更新函数为:
U t ( h v t , m v t + 1 ) = σ ( H t d e g ( v ) m v t + 1 ) U_t(h_v^t,m_v^{t+1}) = \sigma(H_t^{deg(v)}m_v^{t+1}) \\ Ut(hvt,mvt+1)=σ(Htdeg(v)mvt+1)
其中 σ \sigma σ 为 sigmoid 函数, d e g ( v ) deg(v) deg(v) 表示节点 v 的度, H t v H_t^v Htv 是一个可学习的矩阵,t 为时间步,N 为节点度;
读出函数 R 将先前所有隐藏层的状态 h v t h_v^t h