MPNN消息传递机制框架研究(含代码细致研究)

Gilmer 等 (2017) 提出的 Message Passing Neural Networks (MPNNs) 框架,是一种用于处理图结构数据的通用方法,特别适合于量子化学中的分子属性预测问题。以下是对 MPNN 框架的详细分析:

1. 框架概述

MPNN 框架基于图神经网络的思想,将分子看作由节点(代表原子)和边(代表化学键)组成的图结构。MPNN 包括两个主要阶段:

  • 消息传递阶段(Message Passing Phase):通过节点之间的信息交换来更新节点的隐藏状态。
  • 读出阶段(Readout Phase):聚合图中所有节点的信息以生成整个图的特征向量,用于最终的预测任务。

2. 消息传递阶段

  • 在消息传递阶段,MPNN 通过多次迭代来更新每个节点的隐藏状态。每次迭代包括两个步骤:

    • 消息函数 M t M_t Mt:从节点的邻居节点接收信息。消息计算如下:
      m v t + 1 = ∑ w ∈ N ( v ) M t ( h v t , h w t , e v w ) m_{v}^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}) mvt+1=wN(v)Mt(hvt,hwt,evw)
      其中 N ( v ) N(v) N(v) 是节点 v v v 的邻居集合, h v t h_v^t hvt h w t h_w^t hwt 分别为节点 v v v 和其邻居 w w w 的隐藏状态, e v w e_{vw} evw 是边的特征。
    • 节点更新函数 U t U_t Ut:利用接收到的消息来更新节点的隐藏状态:
      h v t + 1 = U t ( h v t , m v t + 1 ) h_{v}^{t+1} = U_t(h_v^t, m_{v}^{t+1}) hvt+1=Ut(hvt,mvt+1)
  • 消息函数和节点更新函数 可以通过不同的神经网络架构(如 G R U GRU GRU 或 &LSTM&)实现,从而提高信息传递和更新的灵活性。

3. 读出阶段

  • 在完成消息传递后,MPNN 使用 读出函数 R R R 聚合所有节点的隐藏状态来计算整个图的输出特征:
    y ^ = R ( { h v T ∣ v ∈ G } ) \hat{y} = R(\{h_v^T | v \in G\}) y^=R({hvTvG})
    • 读出函数需要对节点的排列具有不变性,以确保图同构(Graph Isomorphism)的不变性。

4. MPNN 的创新点

  • 泛化现有模型:MPNN 框架统一了多种现有的图神经网络模型,如 Gated Graph Neural Networks (GG-NN)、Molecular Graph Convolutions 等,使其在一个通用框架下进行对比和改进。
  • 长距离依赖:通过增加虚拟边和主节点的方式,MPNN 能够更有效地捕捉长距离的节点依赖关系,从而提升对分子复杂结构的建模能力。
  • 高效训练:MPNN 在 QM9 数据集上取得了量子力学属性预测的最新结果,同时相较于传统的 DFT 方法,计算效率提升了约 300,000 倍。

5. 应用与效果

  • 数据集:MPNN 在 QM9 数据集上进行了测试,该数据集包含约 130k 个小分子的 13 项量子力学属性预测任务。
  • 性能表现:MPNN 在 13 个预测任务中的 11 项达到了化学精度(Chemical Accuracy),证明其在量子化学预测中的高效性和准确性。

6. 代码分析

对模型中的关键代码进行分析,代码如下:

6.1数据处理模块

6.1.1 数据集侧写

class MyTransform:  # 自定义数据预处理类,用于保留目标值
    def __call__(self, data):  # 使该类可调用,处理传入的数据
        data = copy.copy(data)  # 创建数据副本,避免修改原数据
        data.y = data.y[:, target]  # 仅保留目标值列
        return data  # 返回处理后的数据

class Complete:  # 自定义补全数据类,生成完整的边连接信息
    def __call__(self, data):  # 使该类可调用,处理传入的数据
        data = copy.copy(data)  # 创建数据副本
        device = data.edge_index.device  # 获取数据所在的设备(CPU或GPU)

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)  # 创建节点的行索引
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)  # 创建节点的列索引

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)  # 重构行索引,使每个节点都连接到所有节点
        col = col.repeat(data.num_nodes)  # 重构列索引,确保每个节点都有连接
        edge_index = torch.stack([row, col], dim=0)  # 组合行列索引,得到完整的边连接信息

        edge_attr = None  # 初始化边的属性
        if data.edge_attr is not None:  # 如果原始数据中包含边的属性
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]  # 计算边的索引
            size = list(data.edge_attr.size())  # 获取边属性的大小
            size[0] = data.num_nodes * data.num_nodes  # 修改边属性的大小
            edge_attr = data.edge_attr.new_zeros(size)  # 创建一个新的全零边属性张量
            edge_attr[idx] = data.edge_attr  # 将原始边属性复制到新的边属性张量

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)  # 去除自循环的边
        data.edge_attr = edge_attr  # 更新数据中的边属性
        data.edge_index = edge_index  # 更新数据中的边连接信息

        return data  # 返回处理后的数据

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')  # 设置数据集的路径
transform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])  # 定义数据预处理流程,先后执行MyTransform、Complete和距离变换
dataset = QM9(path, transform=transform).shuffle()  # 加载并随机打乱QM9数据集

# 标准化目标数据,使其均值为0,标准差为1
mean = dataset.data.y.mean(dim=0, keepdim=True)  # 计算目标数据的均值
std = dataset.data.y.std(dim=0, keepdim=True)  # 计算目标数据的标准差
dataset.data.y = (dataset.data.y - mean) / std  # 标准化目标数据
mean, std = mean[:, target].item(), std[:, target].item()  # 获取目标数据的均值和标准差,并转为标量

# 划分数据集为训练集、验证集和测试集
test_dataset = dataset[:10000]  # 获取前10000个数据作为测试集
val_dataset = dataset[10000:20000]  # 获取10000到20000个数据作为验证集
train_dataset = dataset[20000:]  # 获取剩余的数据作为训练集
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)  # 定义测试集的数据加载器
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)  # 定义验证集的数据加载器
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)  # 定义训练集的数据加载器

从以上代码入手,发现数据集是通过QM9()类来调用的,在vscode上通过设置断点进行代码调试,来观察QM9()的信息,如下所示:
QM9数据集
可知,QM9()类包含了许多信息,通过创建QM9()的实例,并对其命名为dataset,可知

  • d a t a s e t . d a t a dataset.data dataset.data表示了QM9数据集的数据,包括了节点特征(原子特征) x x x,边索引 e d g e _ i n d e x edge\_index edge_index,边特征 e d g e _ a t t r edge\_attr edge_attr,每个分子的回归目标 y y y等等;
  • d a t a s e t . n u m _ c l a s s dataset.num\_class dataset.num_class表示了回归目标的个数,共有19个回归目标需要预测;
  • d a t a s e t . n u m _ e d g e _ f e a t u r e s dataset.num\_edge\_features dataset.num_edge_features表示了边特征的个数,且从 e d g e _ a t t r edge\_attr edge_attr的维度中也看出是4;
  • ……不一一赘述

6.1.2 数据预处理

接下来,研究对数据进行预处理的代码

def process(self) -> None:
    try:
        from rdkit import Chem, RDLogger  # 尝试从rdkit导入所需的模块
        from rdkit.Chem.rdchem import BondType as BT  # 从rdkit导入化学键类型
        from rdkit.Chem.rdchem import HybridizationType  # 从rdkit导入原子杂化类型
        RDLogger.DisableLog('rdApp.*')  # 禁用RDKit日志输出,避免冗长的日志信息
        WITH_RDKIT = True  # 设置标志变量,表示RDKit库已成功导入
    except ImportError:
        WITH_RDKIT = False  # 如果没有安装rdkit,则设置为False

    if not WITH_RDKIT:  # 如果未安装rdkit
        print(("Using a pre-processed version of the dataset. Please "
            "install 'rdkit' to alternatively process the raw data."),
            file=sys.stderr)  # 提示用户安装rdkit以进行原始数据处理

        data_list = fs.torch_load(self.raw_paths[0])  # 加载预处理的torch数据
        data_list = [Data(**data_dict) for data_dict in data_list]  # 将数据字典转换为Data对象

        if self.pre_filter is not None:  # 如果存在预处理过滤器
            data_list = [d for d in data_list if self.pre_filter(d)]  # 过滤数据

        if self.pre_transform is not None:  # 如果存在预处理变换
            data_list = [self.pre_transform(d) for d in data_list]  # 对数据进行变换

        self.save(data_list, self.processed_paths[0])  # 保存处理后的数据
        return  # 结束函数
	#--------------------------------------------------------
	# 以上代码在安装了rdkit库以后可以不关注
	
    # 定义原子类型和化学键类型映射
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}  # 原子类型映射
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}  # 键类型映射

    with open(self.raw_paths[1]) as f:  # 打开目标数据文件
        target = [[float(x) for x in line.split(',')[1:20]]  # 解析目标值,从第二列到第二十列
                for line in f.read().split('\n')[1:-1]]  # 跳过文件的第一行和最后一行
        y = torch.tensor(target, dtype=torch.float)  # 将目标值转换为PyTorch张量
        y = torch.cat([y[:, 3:], y[:, :3]], dim=-1)  # 交换目标值的前后3列
        y = y * conversion.view(1, -1)  # 进行单位转换

    with open(self.raw_paths[2]) as f:  # 打开跳过的索引文件
        skip = [int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]]  # 解析跳过的索引

    suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False)  # 加载分子数据集,不去除氢原子,且不进行分子标准化

    data_list = []  # 存储处理后的数据列表
    for i, mol in enumerate(tqdm(suppl)):  # 遍历每个分子,显示进度条
        if i in skip:  # 如果当前分子在跳过列表中,跳过处理
            continue

        N = mol.GetNumAtoms()  # 获取分子的原子数

        conf = mol.GetConformer()  # 获取分子的构象
        pos = conf.GetPositions()  # 获取原子的位置坐标
        pos = torch.tensor(pos, dtype=torch.float)  # 转换为PyTorch张量

        # 初始化各个特征列表
        type_idx = []  # 原子类型索引
        atomic_number = []  # 原子序数
        aromatic = []  # 是否芳香
        sp = []  # 是否sp杂化
        sp2 = []  # 是否sp2杂化
        sp3 = []  # 是否sp3杂化
        num_hs = []  # 氢原子数量

        # 遍历分子中的所有原子
        for atom in mol.GetAtoms():
            type_idx.append(types[atom.GetSymbol()])  # 添加原子类型索引
            atomic_number.append(atom.GetAtomicNum())  # 添加原子序数
            aromatic.append(1 if atom.GetIsAromatic() else 0)  # 判断是否芳香
            hybridization = atom.GetHybridization()  # 获取原子的杂化类型
            sp.append(1 if hybridization == HybridizationType.SP else 0)  # 判断是否sp杂化
            sp2.append(1 if hybridization == HybridizationType.SP2 else 0)  # 判断是否sp2杂化
            sp3.append(1 if hybridization == HybridizationType.SP3 else 0)  # 判断是否sp3杂化

        z = torch.tensor(atomic_number, dtype=torch.long)  # 将原子序数转换为张量

        rows, cols, edge_types = [], [], []  # 初始化边列表和边类型列表
        for bond in mol.GetBonds():  # 遍历分子中的所有化学键
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()  # 获取键的起始和终止原子索引
            rows += [start, end]  # 添加边的起始原子索引
            cols += [end, start]  # 添加边的终止原子索引
            edge_types += 2 * [bonds[bond.GetBondType()]]  # 添加边的类型(双向)

        edge_index = torch.tensor([rows, cols], dtype=torch.long)  # 转换为PyTorch张量,表示边的索引
        edge_type = torch.tensor(edge_types, dtype=torch.long)  # 转换为PyTorch张量,表示边的类型
        edge_attr = one_hot(edge_type, num_classes=len(bonds))  # 对边的类型进行one-hot编码

        perm = (edge_index[0] * N + edge_index[1]).argsort()  # 计算边索引的排序
        edge_index = edge_index[:, perm]  # 对边索引进行排序
        edge_type = edge_type[perm]  # 对边类型进行排序
        edge_attr = edge_attr[perm]  # 对边属性进行排序

        row, col = edge_index  # 解包边索引
        hs = (z == 1).to(torch.float)  # 获取氢原子的标志(值为1表示氢原子)
        num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()  # 计算每个原子的氢原子数量

        # 将不同的原子特征进行one-hot编码和堆叠
        x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))  # 原子类型的one-hot编码
        x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()  # 其他原子特征
        x = torch.cat([x1, x2], dim=-1)  # 将所有特征合并

        # 获取分子的其他信息
        name = mol.GetProp('_Name')  # 获取分子的名称
        smiles = Chem.MolToSmiles(mol, isomericSmiles=True)  # 获取分子的SMILES表示

        # 创建Data对象,包含原子特征、边索引、边属性、目标值等信息
        data = Data(
            x=x,
            z=z,
            pos=pos,
            edge_index=edge_index,
            smiles=smiles,
            edge_attr=edge_attr,
            y=y[i].unsqueeze(0),  # 目标值
            name=name,
            idx=i,  # 分子的索引
        )

        if self.pre_filter is not None and not self.pre_filter(data):  # 如果存在预处理过滤器,且不通过过滤器
            continue
        if self.pre_transform is not None:  # 如果存在预处理变换
            data = self.pre_transform(data)  # 对数据进行变换

        data_list.append(data)  # 将处理后的数据添加到数据列表中

    self.save(data_list, self.processed_paths[0])  # 保存处理后的数据列表

在安装rdkit库之后,前半部分代码基本可以忽略不计,直接关注后半部分。

6.1.2.1 数据文件了解和初步处理

首先,模型对原始数据文件进行处理,QM数据集原始数据文件包括三个文件:

  1. qm9.sdf: 分子结构,打开文件,内容如下:(仅展示两个分子)
    qm9.sdf
    具体含义如下:
    参考链接
    qm9.sdf文件描述
  2. qm9.sdf.csv: 分子性质表,文件内容如下:(仅展示前8行内容)
    qm9.sdf.csv
    该表主要包含分子id和19个目标回归指标,详情如下:(sorry,太爱摸鱼所以就没翻译)
    qm9.sdf.csv文件列解释
    总的来说,QM9数据集就是用于预测这些指标的数据集。原论文利用MPNN框架也是做的这件事。
  3. uncharacterized.txt这个文件很奇怪,在代码中是要展示被跳过的分子是哪些,代码中专门有设定 s k i p skip skip变量来连接这个文件,并对该文件中的分子进行跳过,不做处理。
6.1.2.2 数据处理

之后就是在进行分子特征的提取,关键代码如下:

# 将不同的原子特征进行one-hot编码和堆叠
x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))  # 原子类型的one-hot编码
x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()  # 其他原子特征
x = torch.cat([x1, x2], dim=-1)  # 将所有特征合并

这里的 x x x 张量就是前面提取的特征进行concat操作之后得到的分子特征,维度是[ n u m _ n o d e s num\_nodes num_nodes, 11]。
具体细节是如何构建分子特征的,代码已在上方,鱼认为姑且可以算是后话,闲下来时再进行补充,先允许鱼摸会鱼~

6.2 模型具体架构代码

接下来研究模型具体代码,查看其网络架构。

class Net(torch.nn.Module):  # 定义神经网络模型
    def __init__(self):  # 网络初始化函数
        super().__init__()  # 调用父类的初始化方法
        self.lin0 = torch.nn.Linear(dataset.num_features, dim)  # 第一层全连接层,输入特征维度为dataset.num_features,输出为dim

        nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))  # 定义NNConv的子网络结构
        self.conv = NNConv(dim, dim, nn, aggr='mean')  # 定义图卷积层,使用子网络nn并采取mean聚合方式
        self.gru = GRU(dim, dim)  # 定义GRU网络,用于序列建模

        self.set2set = Set2Set(dim, processing_steps=3)  # 定义Set2Set池化层,处理步骤为3
        self.lin1 = torch.nn.Linear(2 * dim, dim)  # 定义第二层全连接层
        self.lin2 = torch.nn.Linear(dim, 1)  # 定义输出层,输出一个值

    def forward(self, data):  # 前向传播函数
        print(f"We will start with the following input data.")  # 打印将要处理的输入数据
        print(f"Input data shape: {data.x.shape}")  # 打印输入数据的形状
        # 解释:data.x表示节点特征,shape[0]表示节点数,shape[1]表示特征维度。节点也就是原子,特征包括原子序数、坐标等信息。一个data就是一个分子。

        out = F.relu(self.lin0(data.x))  # 输入数据通过第一层全连接层并应用ReLU激活
        print(f"First Linear layer output shape: {out.shape}")  # 打印第一层输出的形状

        h = out.unsqueeze(0)  # 为GRU输入数据增加一个维度,unsqueeze(0)表示在第0维增加一个维度,比如原来是[64],增加后变为[1, 64]
        print(f"GRU input shape: {h.shape}")  # 打印GRU输入的形状
        
        print()  # 打印空行
        print(f"Next, we will perform 3 NNConv and GRU operations.")  # 打印接下来将进行3次NNConv和GRU操作
        for i in range(3):  # 循环进行3次图卷积和GRU操作
            print(f"Starting operation {i + 1}.")  # 打印开始第i+1次操作

            print(f"NNConv input shape: {out.shape}")  # 打印NNConv输入的形状
            print(f"Edge index shape: {data.edge_index.shape}")  # 打印边连接信息的形状
            print(f"Edge attribute shape: {data.edge_attr.shape}")  # 打印边属性的形状

            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))  # 图卷积层输出
            print(f"NNConv output shape and GRU input(less 1 wnsqueeze): {m.shape}")  # 打印图卷积层输出的形状

            print(f"Hiddent state shape before GRU: {h.shape}")  # 打印GRU输入的形状
            out, h = self.gru(m.unsqueeze(0), h)  # 使用GRU处理数据
            print(f"GRU output shape: {out.shape}")  # 打印GRU输出的形状
            print(f"Hidden state shape: {h.shape}")  # 打印隐藏状态的形状

            out = out.squeeze(0)  # 移除多余的维度
            print(f"GRU output shape after squeeze: {out.shape}")  # 打印移除维度后的形状
            print()  # 打印空行
        print(f"Finshed 3 NNConv and GRU operations.")  # 打印完成3次NNConv和GRU操作

        out = self.set2set(out, data.batch)  # 使用Set2Set池化层处理输出
        print(f"Set2Set output shape: {out.shape}")  # 打印Set2Set输出的形状

        out = F.relu(self.lin1(out))  # 输入通过第二层全连接层并应用ReLU激活
        print(f"Second Linear layer output shape: {out.shape}")  # 打印第二层输出的形状

        out = self.lin2(out)  # 输入通过输出层
        print(f"Output layer output shape: {out.shape}")

        # 得到了最终的输出,out是一个标量,其表示的是模型对输入数据的预测值
        return out.view(-1)  # 将输出展平为一维

可以看到,模型的网络架构和前向函数可以说是非常简单的了,回顾消息传递阶段和读出阶段之前所讲公式,这里的代码就是公式的实现。接下来我将逐步分析网络架构:
MPNN张量计算流图

  1. L i n e a r Linear Linear 线性层进行维度变换,将11维的节点输入特征对齐到64维,映射到高维特征空间以便学习到更复杂的表示;
  2. N N C o n v NNConv NNConv 图卷积,网络中最最关键的一层架构,用于消息传递和聚合。采取 m e a n mean mean 方式聚合每个节点的邻居信息,并对节点特征进行更新;
  3. G R U GRU GRU 门控循环单元,用于序列建模。在消息传递过程中,节点的特征在多个时间步上变化, G R U GRU GRU 用于保持历史信息,使节点特征更新更加稳定并有效地整合之前的状态 G R U GRU GRU 通过其内部的门控机制,可以有效避免梯度消失和爆炸的问题。
  4. S e t 2 S e t Set2Set Set2Set 是一种专门用于处理集合结构的池化方法,它通过多次迭代操作将节点级别的信息整合为图级别的全局表示。 S e t 2 S e t Set2Set Set2Set 池化可以将图中的所有节点的特征聚合成一个固定大小的向量,这个特征向量代表整个图(分子)的特征
  5. 线性层降维并输出,最终输出是一维向量,表示模型对输入分子的预测值。

了解完整体网络架构后,现在开始研究其中的关键组分
在这里插入图片描述
在代码中增加print语句,来查看张量计算流变化,(也可以通过python调试)接下来将结合以上打印内容进行分析

6.2.1 NNConv图卷积

图卷积操作用于聚合每个节点的邻居信息,并对节点特征进行更新,查看NNConv代码,发现其关键函数如下,对其进行研究。

6.2.1.1 propogate函数
 def propagate(
     self,
     edge_index: Adj,  # 输入的图的边索引,可以是稀疏张量或其他表示方式
     size: Size = None,  # 图的大小,默认值为None,表示自动推断大小
     **kwargs: Any,  # 其他传递的参数,用于构造和更新节点嵌入
 ) -> Tensor:
     r"""The initial call to start propagating messages.这是开始传递消息的初始调用。
     
     Args:
         edge_index (torch.Tensor or SparseTensor): 定义图的连接关系/消息传递的稀疏矩阵。
         size ((int, int), optional): 赋值矩阵的大小,默认自动推断。
         **kwargs: 其他需要的数据,用于构造和聚合消息并更新节点嵌入。
     """
     decomposed_layers = 1 if self.explain else self.decomposed_layers  # 确定是否需要进行分解层处理

     # 调用所有前向传播前的钩子函数,钩子函数指的是在前向传播之前对输入进行修改的函数
     for hook in self._propagate_forward_pre_hooks.values():
         res = hook(self, (edge_index, size, kwargs))  # 钩子函数调用,修改输入
         if res is not None:  # 如果钩子函数返回值不为None,更新输入
             edge_index, size, kwargs = res

     mutable_size = self._check_input(edge_index, size)  # 检查并确定输入大小

     # 运行“融合”消息和聚合(如果适用)
     fuse = False  # 初始化是否融合的标志位
     if self.fuse and not self.explain:  # 如果支持融合且不在解释模式下,解释模式就是用于解释模型的模式
         if is_sparse(edge_index):  # 如果输入是稀疏张量
             fuse = True
         elif (not torch.jit.is_scripting()
             and isinstance(edge_index, EdgeIndex)):  # 如果不是在脚本模式且edge_index是EdgeIndex类型
             if (self.SUPPORTS_FUSED_EDGE_INDEX
                     and edge_index.is_sorted_by_col):  # 支持融合且列已排序
                 fuse = True

     if fuse:  # 如果支持融合
         coll_dict = self._collect(self._fused_user_args, edge_index,
                                 mutable_size, kwargs)  # 收集用户输入参数

         msg_aggr_kwargs = self.inspector.collect_param_data(
             'message_and_aggregate', coll_dict)  # 收集用于消息传递和聚合的参数
         for hook in self._message_and_aggregate_forward_pre_hooks.values():
             res = hook(self, (edge_index, msg_aggr_kwargs))  # 调用前置钩子
             if res is not None:
                 edge_index, msg_aggr_kwargs = res
         out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)  # 进行消息传递和聚合操作
         for hook in self._message_and_aggregate_forward_hooks.values():
             res = hook(self, (edge_index, msg_aggr_kwargs), out)  # 调用后置钩子
             if res is not None:
                 out = res

         update_kwargs = self.inspector.collect_param_data(
             'update', coll_dict)  # 收集用于更新节点嵌入的参数
         out = self.update(out, **update_kwargs)  # 更新节点嵌入

     else:  # 否则,分别运行消息传递和聚合函数
         if decomposed_layers > 1:  # 如果有多层分解
             user_args = self._user_args
             decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}  # 找到以'_j'结尾的用户参数
             decomp_kwargs = {
                 a: kwargs[a].chunk(decomposed_layers, -1)  # 将参数在最后一个维度上分块
                 for a in decomp_args
             }
             decomp_out = []  # 存储每层的输出

         for i in range(decomposed_layers):  # 遍历每个分解层
             if decomposed_layers > 1:
                 for arg in decomp_args:
                     kwargs[arg] = decomp_kwargs[arg][i]  # 使用分解后的参数

             coll_dict = self._collect(self._user_args, edge_index,
                                     mutable_size, kwargs)  # 收集输入参数

             msg_kwargs = self.inspector.collect_param_data(
                 'message', coll_dict)  # 收集消息传递参数
             for hook in self._message_forward_pre_hooks.values():
                 res = hook(self, (msg_kwargs, ))  # 调用前置消息钩子
                 if res is not None:
                     msg_kwargs = res[0] if isinstance(res, tuple) else res
             out = self.message(**msg_kwargs)  # 进行消息传递操作
             for hook in self._message_forward_hooks.values():
                 res = hook(self, (msg_kwargs, ), out)  # 调用后置消息钩子
                 if res is not None:
                     out = res

             if self.explain:  # 如果在解释模式下
                 explain_msg_kwargs = self.inspector.collect_param_data(
                     'explain_message', coll_dict)  # 收集解释消息参数
                 out = self.explain_message(out, **explain_msg_kwargs)  # 解释消息

             aggr_kwargs = self.inspector.collect_param_data(
                 'aggregate', coll_dict)  # 收集聚合参数
             for hook in self._aggregate_forward_pre_hooks.values():
                 res = hook(self, (aggr_kwargs, ))  # 调用前置聚合钩子
                 if res is not None:
                     aggr_kwargs = res[0] if isinstance(res, tuple) else res

             out = self.aggregate(out, **aggr_kwargs)  # 聚合消息

             for hook in self._aggregate_forward_hooks.values():
                 res = hook(self, (aggr_kwargs, ), out)  # 调用后置聚合钩子
                 if res is not None:
                     out = res

             update_kwargs = self.inspector.collect_param_data(
                 'update', coll_dict)  # 收集更新参数
             out = self.update(out, **update_kwargs)  # 更新节点嵌入

             if decomposed_layers > 1:  # 如果有多层分解
                 decomp_out.append(out)  # 将当前层的输出添加到分解输出列表中

         if decomposed_layers > 1:  # 如果有多层分解
             out = torch.cat(decomp_out, dim=-1)  # 将所有分解层的输出在最后一个维度上拼接

     for hook in self._propagate_forward_hooks.values():
         res = hook(self, (edge_index, mutable_size, kwargs), out)  # 调用后置传播钩子
         if res is not None:
             out = res

     return out  # 返回最终的传播输出

代码不支持融合,因此只用看后半else部分,分别进行传递和聚合。
奇怪的是,鱼对代码单步调试,发现其中所有的hook都没有运行到。

6.2.1.2 message函数

此外,对代码进行单步调试 ,发现消息传递函数如下

def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:  # 消息传递函数
    '''
    参数:
        x_j (Tensor):邻居节点的特征
        edge_attr (Tensor):边的特征
    返回:
        msg (Tensor):消息    
    '''ni
    weight = self.nn(edge_attr)  # 使用神经网络对边特征进行变换,得到权重

    weight = weight.view(-1, self.in_channels_l, self.out_channels)  

    msg = torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)  # 将邻居节点特征与权重相乘,得到消息

    return msg
6.2.1.3 aggregate函数

聚合函数最终是对以下MeanAggregation类的实现,

class MeanAggregation(Aggregation):
    r"""An aggregation operator that averages features across a set of
    elements.

    .. math::
        \mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|}
        \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')

reduce函数调用了scatter函数

def reduce(self, x: Tensor, index: Optional[Tensor] = None,
           ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
           dim: int = -2, reduce: str = 'sum') -> Tensor:
    '''
    参数:
        x (Tensor):输入张量
        index (Tensor, optional):索引张量
        ptr (Tensor, optional):指针张量
        dim_size (int, optional):维度大小
        dim (int, optional):聚合维度
        reduce (str):聚合方式(“sum”,“mean”,“min”,“max”)
    '''

    if ptr is not None:
        if index is None or self._deterministic:
            ptr = expand_left(ptr, dim, dims=x.dim())
            return segment(x, ptr, reduce=reduce)

    if index is None:
        raise RuntimeError("Aggregation requires 'index' to be specified")

    return scatter(x, index, dim, dim_size, reduce)

scatter函数具体如下:

def scatter(
    src: Tensor,
    index: Tensor,
    dim: int = 0,
    dim_size: Optional[int] = None,
    reduce: str = 'sum',
) -> Tensor:
    r"""Reduces all values from the :obj:`src` tensor at the indices
    specified in the :obj:`index` tensor along a given dimension
    :obj:`dim`. See the `documentation
    <https://pytorch-scatter.readthedocs.io/en/latest/functions/
    scatter.html>`__ of the :obj:`torch_scatter` package for more
    information.

    Args:
        src (torch.Tensor): The source tensor.
        index (torch.Tensor): The index tensor.
        dim (int, optional): The dimension along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): The size of the output tensor at
            dimension :obj:`dim`. If set to :obj:`None`, will create a
            minimal-sized output tensor according to
            :obj:`index.max() + 1`. (default: :obj:`None`)
        reduce (str, optional): The reduce operation (:obj:`"sum"`,
            :obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`,
            :obj:`"any"`). (default: :obj:`"sum"`)
    """
    if isinstance(index, Tensor) and index.dim() != 1:
        raise ValueError(f"The `index` argument must be one-dimensional "
                         f"(got {index.dim()} dimensions)")

    dim = src.dim() + dim if dim < 0 else dim

    if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):
        raise ValueError(f"The `dim` argument must lay between 0 and "
                         f"{src.dim() - 1} (got {dim})")

    if dim_size is None:
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

    # For now, we maintain various different code paths, based on whether
    # the input requires gradients and whether it lays on the CPU/GPU.
    # For example, `torch_scatter` is usually faster than
    # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster
    # on CPU.
    # `torch.scatter_reduce` has a faster forward implementation for
    # "min"/"max" reductions since it does not compute additional arg
    # indices, but is therefore way slower in its backward implementation.
    # More insights can be found in `test/utils/test_scatter.py`.

    size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]

    # For "any" reduction, we use regular `scatter_`:
    if reduce == 'any':
        index = broadcast(index, src, dim)
        return src.new_zeros(size).scatter_(dim, index, src)

    # For "sum" and "mean" reduction, we make use of `scatter_add_`:
    if reduce == 'sum' or reduce == 'add':
        index = broadcast(index, src, dim)
        return src.new_zeros(size).scatter_add_(dim, index, src)

    if reduce == 'mean':
        count = src.new_zeros(dim_size) # 初始化一个全0的tensor,大小为dim_size
        count.scatter_add_(0, index, src.new_ones(src.size(dim))) # 将index中的元素作为索引,将src中的元素作为值,将1加到count中
        count = count.clamp(min=1) # 将count中的元素限制在1以上

        index = broadcast(index, src, dim) # 将index扩展到src的维度
        out = src.new_zeros(size).scatter_add_(dim, index, src) # 将src中的元素根据index的值,加到out中

        return out / broadcast(count, out, dim) # 将out中的元素除以广播后的count,广播函数把count扩展到out的维度

    # For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
    # in case the input does not require gradients:
    if reduce in ['min', 'max', 'amin', 'amax']:
        if (not torch_geometric.typing.WITH_TORCH_SCATTER
                or is_compiling() or is_in_onnx_export() or not src.is_cuda
                or not src.requires_grad):

            if (src.is_cuda and src.requires_grad and not is_compiling()
                    and not is_in_onnx_export()):
                warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
                              f"can be accelerated via the 'torch-scatter'"
                              f" package, but it was not found")

            index = broadcast(index, src, dim)
            if not is_in_onnx_export():
                return src.new_zeros(size).scatter_reduce_(
                    dim, index, src, reduce=f'a{reduce[-3:]}',
                    include_self=False)

            fill = torch.full(  # type: ignore
                size=(1, ),
                fill_value=src.min() if 'max' in reduce else src.max(),
                dtype=src.dtype,
                device=src.device,
            ).expand_as(src)
            out = src.new_zeros(size).scatter_reduce_(
                dim, index, fill, reduce=f'a{reduce[-3:]}',
                include_self=True)
            return out.scatter_reduce_(dim, index, src,
                                       reduce=f'a{reduce[-3:]}',
                                       include_self=True)

        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
                                     reduce=reduce[-3:])

    # For "mul" reduction, we prefer `scatter_reduce_` on CPU:
    if reduce == 'mul':
        if (not torch_geometric.typing.WITH_TORCH_SCATTER
                or is_compiling() or not src.is_cuda):

            if src.is_cuda and not is_compiling():
                warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
                              f"can be accelerated via the 'torch-scatter'"
                              f" package, but it was not found")

            index = broadcast(index, src, dim)
            # We initialize with `one` here to match `scatter_mul` output:
            return src.new_ones(size).scatter_reduce_(
                dim, index, src, reduce='prod', include_self=True)

        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
                                     reduce='mul')

    raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")

以上scatter函数可以不用看,因为关键实现框架mean聚合的代码只有以下部分

if reduce == 'mean':
    count = src.new_zeros(dim_size) # 初始化一个全0的tensor,大小为dim_size
    count.scatter_add_(0, index, src.new_ones(src.size(dim))) # 将index中的元素作为索引,将src中的元素作为值,将1加到count中
    count = count.clamp(min=1) # 将count中的元素限制在1以上

    index = broadcast(index, src, dim) # 将index扩展到src的维度
    out = src.new_zeros(size).scatter_add_(dim, index, src) # 将src中的元素根据index的值,加到out中

    return out / broadcast(count, out, dim) # 将out中的元素除以广播后的count,广播函数把count扩展到out的维度

6.2.2 Set2Set全局池化

class Set2Set(Aggregation):  # 定义Set2Set类,继承自Aggregation类,表示一种基于迭代内容的注意力聚合操作
    r"""Set2Set聚合操作,基于迭代内容式注意力机制,详见论文
    `"Order Matters: Sequence to sequence for Sets" <https://arxiv.org/abs/1511.06391>`_

    数学公式:
        \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})  # q_t由LSTM生成,输入为上一步的q_star

        \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)  # 计算注意力权重,基于节点特征和q_t的点积

        \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i  # 加权求和节点特征,得到r_t

        \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t  # 将q_t和r_t拼接,得到q_star

    Args:
        in_channels (int): 输入特征的大小。
        processing_steps (int): 迭代次数T。
        **kwargs (optional): 额外的参数,传递给torch.nn.LSTM。
    """
    def __init__(self, in_channels: int, processing_steps: int, **kwargs):
        super().__init__()  # 调用父类初始化方法
        self.in_channels = in_channels  # 保存输入特征的维度
        self.out_channels = 2 * in_channels  # 输出特征的维度是输入的两倍,用于拼接q_t和r_t
        self.processing_steps = processing_steps  # 迭代次数T
        self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs)  # 定义LSTM,用于更新q_t和r_t
        self.reset_parameters()  # 初始化LSTM参数

    def reset_parameters(self):
        self.lstm.reset_parameters()  # 重置LSTM层的参数

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # 前向传播函数,计算Set2Set聚合操作

        self.assert_index_present(index)  # 确保索引存在
        self.assert_two_dimensional_input(x, dim)  # 确保输入x是二维的

        # 初始化LSTM的隐藏状态和q_star,分别为零张量
        h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))),
             x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))))  # LSTM的初始隐藏状态
        q_star = x.new_zeros(dim_size, self.out_channels)  # 初始化q_star为零张量

        for _ in range(self.processing_steps):  # 迭代指定次数
            q, h = self.lstm(q_star.unsqueeze(0), h)  # 通过LSTM更新q_t和隐藏状态h
            q = q.view(dim_size, self.in_channels)  # 重塑q_t的形状
            e = (x * q[index]).sum(dim=-1, keepdim=True)  # 计算每个节点的注意力得分
            a = softmax(e, index, ptr, dim_size, dim)  # 对得分应用softmax,得到注意力权重
            r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='sum')  # 根据注意力权重聚合邻居节点的特征
            q_star = torch.cat([q, r], dim=-1)  # 将q_t和r_t拼接,得到新的q_star

        return q_star  # 返回最终的q_star,表示图的全局特征

    def __repr__(self) -> str:
        # 重写__repr__方法,用于打印Set2Set对象的简洁表示
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')  # 返回类名和输入输出特征的维度

Set2Set数学公式如下:
q t = L S T M ( q t − 1 ∗ ) q t 由LSTM生成,输入为上一步的 q t − 1 ∗ q_t = LSTM({q}^{*}_{t-1}) \quad q_t \text{由LSTM生成,输入为上一步的} q^*_{t-1} qt=LSTM(qt1)qtLSTM生成,输入为上一步的qt1
α i , t = s o f t m a x ( x i ⋅ q t ) 计算注意力权重,基于节点特征和 q t 的点积 \alpha_{i,t} = \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) \quad \text{计算注意力权重,基于节点特征和} q_t \text{的点积} αi,t=softmax(xiqt)计算注意力权重,基于节点特征和qt的点积
r t = ∑ i = 1 N α i , t x i 加权求和节点特征,得到 r t {r}_t = \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i \quad \text{加权求和节点特征,得到} r_t rt=i=1Nαi,txi加权求和节点特征,得到rt
q t ∗ = q t   ∥   r t 将 q t 和 r t 拼接,得到 q t ∗ {q}^{*}_t = {q}_t \, \Vert \, {r}_t \quad \text{将} q_t \text{和} r_t \text{拼接,得到} q^*_t qt=qtrtqtrt拼接,得到qt
总结来说,MPNN 提出了一个通用且灵活的框架,用于对图结构数据进行高效的特征学习和预测,尤其在量子化学和材料科学领域具有巨大的应用潜力。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值