图神经网络二:消息传递图神经网络

 

引言

这篇是datawhale组队学习之图神经网络第二篇,本笔记主要梳理课程的关键点,以及简单的代码实现。

  • 首先我们将学习图神经网络生成节点表征的范式–消息传递(Message Passing)范式

  • 接着我们将初步分析PyG中的MessagePassing基类,通过继承此基类我们可以方便地构造一个图神经网络。

  • 然后我们以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来构造图神经网络。

  • 再接着我们将对MessagePassing基类进行剖析。

  • 最后我们将学习在继承MessagePassing基类的子类中覆写message(),aggreate(),message_and_aggreate()update(),这些方法的规范。

消息传递(Message Passing)范式

图神经网络 (GNN)主要是靠图卷积操作来完成的。而图卷积操作是一种将目标节点周围邻接节点的信息进行聚合的一种方法,:

\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)

 \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F为 (k-1) 层节点 i  的特征向量,  \mathbf{e}_{j,i} \in \mathbb{R}^D为 i 到 j 的边的特征向量。

 \square 为聚合方法(可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数;例如:sum()函数、mean()函数和max()函数 )

\gamma  和 \phi 分别是两个可以学习的层(如MLPs)

于是,根据这个公式,我们要做的就变成了三件事:

  • A. 邻接节点信息的变换:\phi
  • B. 邻接节点信息聚合:\square
  • C. 自己的信息与聚合后的邻接节点信息的变换:\gamma

在Pytorch Geometric(PyG) 中,这个流程被对应到self.propagate这个操作中,self.propagate将分别执行上述三件事:

  • A. 执行self.message,对应公式中 ,即邻接节点信息的变换
  • B. 执行self.aggregate,对应公式中  ,即邻接节点信息聚合
  • C. 执行self.update,对应公式中  ,即自己的信息与聚合后的邻接节点 信息的变换

MessagePassing实例

我们以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来实现一个简单的图神经网络。

GCNConv的数学定义为

\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)

 

归一化系数计算对应\phi,求和对应聚合方法\square,该方法中没有变换\gamma。除此之外,需要在self.forward进行初始特征的一次变换\mathbf{\Theta},整体流程如下:

  • A. self.message:\phi计算归一化系数
  • B. self.aggregate:\square,选择add
  • C. self.update: \gamma(无)

其中,邻接节点的表征首先通过与权重矩阵相乘进行变换,然后按端点的度进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。

  2. 对节点表征做线性转换。

  3. 计算归一化系数。

  4. 归一化邻接节点的节点表征。

  5. 将相邻节点表征相加("求和 "聚合)。

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # /space 选择聚合方法
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Step 1: 添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: /theta 初始特征的一次变换
        x = self.lin(x)

        # Step 5: 相邻节点表征相加("求和 "聚合)
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        # Step 3:/phi 计算归一化系数
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # Step 4: 归一化邻接节点的节点表征
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # /gamma 无
        return aggr_out

 完整代码: 

Planetoid数据集类的官方文档为torch_geometric.datasets.Planetoid

# -*- coding: utf-8 -*-
"""
Created on Sat Jun 19 11:37:18 2021

@author: Choi
"""

import os
import torch
from torch_geometric.datasets import Planetoid # PyG处理好的一些数据,如"Cora", "CiteSeer" and "PubMed" ,用Planetoid这个类调用即可
import torch_geometric.nn as pyg_nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

# load Cora dataset
def get_data(folder="node_classify/cora", data_name="cora"):
    """
    :param folder:保存数据集的根目录。
    :param data_name:数据集的名称
    :return:返回的是一个对象,就是PyG文档里的Data对象,它有一些属性,如 data.x、data.edge_index等
    """
    dataset = Planetoid(root=folder, name=data_name)
    return dataset

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # /space 选择聚合方法
        self.lin = torch.nn.Linear(in_channels, out_channels)
 
    def forward(self, x, edge_index):
        # Step 1: 添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: /theta 初始特征的一次变换
        x = self.lin(x)
 
        # Step 5: 相邻节点表征相加("求和 "聚合)
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
 
    def message(self, x_j, edge_index, size):
        # Step 3:/phi 计算归一化系数
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # Step 4: 归一化邻接节点的节点表征
        return norm.view(-1, 1) * x_j
 
    def update(self, aggr_out):
        # /gamma 无
        return aggr_out
    
    
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 配置GPU
    
    dataset = get_data()
    data = dataset[0]
     
    net = GCNConv(data.num_features, 64)
    h_nodes = net(data.x, data.edge_index)
    print(h_nodes.shape)
    
if __name__ == "__main__":
    main()

 

其他资料:

1.Torch geometric GCNConv 源码分析   

 

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值