消息传递神经网络(pyG实现GCN层)

一、引言

为节点生成节点表征是图计算任务成功的关键,神经网络的生成节点表征的操作叫做节点嵌入(node embeddi ng)

二、消息传递范式介绍

在这里插入图片描述

基于消息传递范式的生成节点表征的过程:
在这里插入图片描述
我们从左往右来看此图。图的左边是我们输入的整张图(INPUT GRAPH),由ABCDEFG六个节点组成,现在目标是得到更新之后A节点(target node)表示。
再看右边,与A相邻的BCD三个点进行变换聚合后就会得到更新后的A节点信息。同理所有的节点都与它相邻节点有关,更新后的信息为相邻节点变换和聚合后的特征信息。消息传递图神经网络是指遵循“消息传递范式”的图神经网络,此类图神经网络实现了上述的节点信息更新过程。

注:未经过训练的图神经网络生成的节点表征还不是好的节点表征,好的节点表征可用于衡量节点之间的相似性。

三、消息传递的实现(pyG)

1、MessagePassing基类

Pytorch Geometric(PyG)提供了 MessagePassing基类,它封装了“消息传递”的运行流程。通过继承 MessagePassing 基类,可以方便地构造消息传递图神经网络,构造一个最简单的消息传递图神经网络类,我们只需定义聚合和更新的方法即可。

  • MessagePassing(aggr=“add”, flow=“source_to_target”,node_dim=-2) :

     aggr :定义要使用的聚合方案("add"、"mean "或 "max");
     flow :定义消息传递的流向("source_to_target "或"target_to_source");
     node_dim :定义沿着哪个轴线传播
    
  • MessagePassing.propagate(edge_index, size=None,**kwargs) :

     开始传递消息的起始调用,在此方法中 message 、 update 等方法被调用。
     它以 edge_index (边的端点的索引)和 flow (消息的流向)以及一些额外的数据
     为参数。
    
  • MessagePassing.aggregate(…) :

    将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式
    有 sum , mean 和 max 。
    
  • MessagePassing.message(…):

     接收传递给MessagePassing.propagate(edge_index, size=None,**kwargs) 方法的所有参数
    
  • MessagePassing.update(aggr_out, …) :

    每个节点更新节点表征。此方法以aggregate 方法的输出为第一个参      
    数,并接收所有传递给propagate() 方法的参数。
    

2、继承MessagePassing实现GCNConv

在这里插入图片描述将上面的公式写成代码

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import warnings
warnings.filterwarnings("ignore")


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息从源节点传播到目标节点
        self.lin = torch.nn.Linear(in_channels,out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]  edge_index (边的端点的索引)

        # Step 1: 在邻接矩阵中加入自循环的边
        edge_index, _ = add_self_loops(edge_index,num_nodes=x.size(0))
        # Step 2: 线性变换节点特征矩阵。
        x = self.lin(x)
        # Step 3: 计算归一化系数norm
        #归一化系数是由每个节点的节点度
        #return edge_index, edge_weight
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        ## Step 4-5: 消息传递
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: 节点特征归一化.
        return norm.view(-1, 1) * x_j

GCNConv 继承了 MessagePassing 并以"求和"作为领域节点信息聚合方式。
该层的所有逻辑都发生在其 forward() 方法中。

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值