MPNN消息传递神经网络论文阅读

文章介绍了MPNN方法,用于预测化学分子的量子化学性质,相较于DFT方法速度快、误差小。作者构建了一个基于GG-NN的框架,通过可学习的组件和改进的交互机制,如Messagepassing、Vertexupdating和Readout,实现在QM-9数据集上的卓越性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

最近读了Neural Message Passing for Quantum Chemistry(MPNN arxiv),我导和我一致认为这篇论文有点拉跨,但毕竟是Message Passing Neural Networks的首次提出,所以还是讲一下。本文旨在提供一个high level的idea,并不讨论细节。

任务

给定化学分子,对它的量子化学性质进行预测,DFT方法计算昂贵,本文提出了MPNN方法,速度很快,且误差也很小。

方法

1. 框架

作者从六大类图神经网络中抽象出了一个框架:MPNN,并且以GG-NN作为baseline(原文:We began our exploration of MPNNs around the GG-NN model which we believe to be a strong baseline.)。

该框架分为三个部分:Message passingVertex updatingReadout

M_t是一个函数/神经网络,Message passing就是说用顶点v以及它的邻居w的隐藏状态(hidden state),以及连接它们的边的信息,生成信息m_v^{t+1}

U_t是一个函数/神经网络,Vertex updating就是说用刚才生成的m_v^{t+1},和自己的隐藏状态,更新为h_v^{t+1}

R是一个函数/神经网络,Readout就是说整合所有节点T次迭代(Message passing+Vertex updating)之后的隐藏状态,做一个输出。

2. baseline

作者的思想就是以GG-NN为baseline,然后替换三大部分中的构件,取得一个好的效果。

Message passing:使用可学习的、与边vw相关的矩阵A,去乘节点w的隐藏状态。缺点是没有利用顶点v的隐藏状态。

Vertex updating:GRU是Gated Recurrent Unit,读者可以自行去了解。

Readout:熟悉GLU(Gated Linear Units)的读者会一眼看出这个结构:σ()⊙()。内部的i和j是两个不同的神经网络,i将初始状态也作为输入。但是二者的输出维度必须是一样的,因为要做对应元素相乘⊙。

3. Interaction

f是一个神经网络,它用到了顶点v的隐藏状态。

4. set2set

论文里说的是将元组(h_v^T,x_v)进行线性投影,形成T=\{(h_v^T,x_v)\},然后经过M步计算,得到一个图级的嵌入q_t^*,再将q_t^*输入到一个神经网络中得到最终的输出。

5. multiple towers

为了提高模型的运行速度以及对大图的泛化能力,作者提出了多塔(其实就是MapReduce的思想)。节点属性的维度是d,现在将属性划分为k个维度是d/k的属性,每个分布式分别迭代所有节点的d/k属性,最后再拼接起来输入到一个神经网络中进行readout。

实验

1. 数据集

作者选择了QM-9数据集,内含130462个分子,随机选择其中的10000个作为验证集,10000个作为测试集。分子的性质包括原子类型、质子数等。

2. 步骤

这里放原文的一句话:We focused on trying different message functions, output functions, finding the appropriate input representation, and properly tuning hyperparameters.其实就是炼丹。

结果

1. 超越baseline

在所有的13个性质预测中都取得了最小的误差(Target是13个预测任务,左边是之前的方法,中间是baseline,右侧是作者的方法)。

2. 类似消融实验

Virtual edge虚拟边(对没连接的节点进行连接) ,Master node主节点(与所有的节点相连),set2set就是之前提到的readout阶段的东西。可以看出set2set效果是最好的。

对比joint training(13个任务联合训练)和individual training(对每个任务分别训练),肯定是有针对性效果更好;对比有8个tower和没有,有塔的误差更小,达到了泛化能力强的目标。

结语

有问题欢迎在评论区探讨!

<think>好的,用户想要了解如何用Python实现消息传递网络。首先,我需要明确消息传递网络(MPNN)是什么。根据之前的知识,MPNN是一种用于图神经网络的框架,常用于处理图结构数据。接下来,我需要考虑用户可能的背景,他们可能对图神经网络有一定了解,但需要具体的实现步骤和示例代码。 首先,我应该解释消息传递网络的基本原理,包括消息传递、聚合和更新步骤。然后,提供一个简单的Python示例,使用常见的库如PyTorch Geometric,因为这会简化图神经网络的实现。用户可能已经安装了PyTorch,但可能需要指导安装PyTorch Geometric,所以需要给出安装命令。 接下来,我需要检查是否有相关的引用资料。用户提供的引用中,引用[2]和[3]提到了Python实现消息传递神经网络,但具体到MPNN可能需要结合这些内容。例如,引用[3]中的BP神经网络可能与MPNN的结构不同,但反向传播的思想可能有参考价值。而引用[2]关于Kafka的消息传递可能不直接相关,但用户可能混淆了消息传递网络与消息队列系统,需要明确区分。 然后,我需要构建一个示例代码,包括定义网络结构、数据准备和训练过程。考虑到用户可能需要一个完整的例子,代码应该尽可能简单,同时展示关键步骤。比如,使用PyTorch Geometric的MessagePassing基类,定义消息函数和更新函数,并应用到合成数据上。 同时,需要注意用户提到的“消息传递网络”是否指的是图神经网络中的MPNN,或者有其他定义。这里假设是图神经网络中的MPNN,因此需要确保示例正确反映这一点。可能需要引用相关论文或文档,但用户提供的引用中没有直接相关的,所以可能需要内部知识。 最后,生成相关问题,帮助用户进一步学习,比如优化方法、与其他框架的对比和应用场景。确保问题与主题相关,并引导用户深入理解。 总结步骤:解释MPNN原理,提供PyTorch Geometric示例,指导安装,生成相关问题,并确保引用正确。同时,检查代码是否正确,语法是否合规,比如使用$$...$$格式的数学公式。</think>### 如何用 Python 实现消息传递网络(MPNN消息传递网络(Message Passing Neural Network, MPNN)是一种基于图结构的神经网络框架,常用于处理图数据(如分子结构、社交网络)。其核心思想是通过节点间的消息传递、聚合和更新来学习图的表示。以下是实现步骤和示例: --- ### 1. 消息传递网络原理 - **消息传递**:每个节点将自身特征与邻接节点特征结合生成消息,例如: $$m_{ij} = f(x_i, x_j, e_{ij})$$ 其中 $x_i$ 是节点特征,$e_{ij}$ 是边特征。 - **聚合**:节点收集所有邻接节点的消息并聚合,例如: $$m_i = \sum_{j \in \mathcal{N}(i)} m_{ij}$$ - **更新**:根据聚合后的消息更新节点特征: $$x_i' = g(x_i, m_i)$$ --- ### 2. Python 实现示例(使用 PyTorch Geometric) #### 环境准备 ```bash pip install torch torch-geometric ``` #### 代码实现 ```python import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops class MPNNLayer(MessagePassing): def __init__(self, in_channels, out_channels): super(MPNNLayer, self).__init__(aggr='add') # 聚合方式为求和 self.lin = torch.nn.Linear(in_channels * 2, out_channels) def forward(self, x, edge_index): # 添加自环边(可选) edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) return self.propagate(edge_index, x=x) def message(self, x_i, x_j): # 定义消息生成函数:拼接源节点和目标节点特征 return self.lin(torch.cat([x_i, x_j], dim=-1)) def update(self, aggr_out): # 定义更新函数:直接返回聚合结果 return aggr_out # 示例数据 x = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) # 3个节点,特征维度1 edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) # 边连接关系 # 初始化MPNN层并运行 mpnn = MPNNLayer(in_channels=1, out_channels=2) output = mpnn(x, edge_index) print("更新后的节点特征:\n", output) ``` --- ### 3. 关键步骤解释 1. **消息生成**:通过 `message` 方法将相邻节点的特征拼接并通过线性层。 2. **消息聚合**:使用 `aggr='add'` 对邻接节点的消息求和。 3. **特征更新**:在 `update` 中直接返回聚合后的结果(可根据需求添加激活函数)。 --- ### 4. 应用场景 - **化学分子属性预测**:预测分子毒性或溶解度[^3]。 - **社交网络分析**:识别社区结构或用户行为模式。 - **推荐系统**:基于用户-商品交互图生成推荐。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Burger~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值