task02:消息传递范式

本文介绍了消息传递范式在图神经网络中的应用,详细阐述了节点表征生成的过程,并以PyTorch Geometric中的MessagePassing基类为例,解释如何构建消息传递图神经网络。通过message()和update()函数定义,以及不同的消息聚合方案,如add、mean和max,实现边缘卷积等操作。
摘要由CSDN通过智能技术生成

简单理解

消息传递范式是实现图神经网络的一种通用范式。消息传递范式遵循“消息传播-》消息聚合-》消息更新”这一过程,实现将邻接节点的信息聚合到中心节点上。messagePassing类大大方便了我们图神经网络的构建。

基于消息传递范式的生成节点表征的过程

在这里插入图片描述

1:在图的最右侧,B节点的邻接节点(A,C)的消息传递给了B,经过信息变换得到了B的嵌入,C,D节点相同
2:在图的最右侧,A节点的邻接节点(B,C,D)的之前得到的节点嵌入传递给了节点A;在图的左侧,聚合得到的信息经过信息变换得到了A节点新的嵌入。
3:重复多次,我们可以得到每一个节点的多次信息变换的嵌入。这样的节点经过多次信息聚合与变换的节点嵌入就可以作为节点的表征,可以用于节点的分类。
在这里插入图片描述

Pytorch Geometric中的MessagePassing基类

Pytorch Geometric(PyG)提供了MessagePassing基类,它实现了消息传播的自动处理,继承该基类可使我们方便地构造消息传递图神经网络,我们只需定义函数 ,即message() 函数,和函数 ,即 update()函数,以及使用的消息聚合方案,即aggr=“add” 、aggr="mean"或 aggr=“max” 。

继承MessagePassing类的GCNConv

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

实现边缘卷积

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值