图神经网络 torch_geometric 库的 MessagePassing 运行机制学习

torch_geometric.nn.conv. MessagePassing( )
继承这个类,可以自定义节点信息传播机制

例子

import torch
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.conv import MessagePassing

class GCNConv(MessagePassing):
    def __init__(self):
    	#选择相加的方式进行邻居节点信息聚合
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
    	#给图添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        print(edge_index)
        out = self.propagate(edge_index, x=x)
        print(out)

    def message(self, x_j):
        print(x_j)
        return x_j




edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [6], [1]], dtype=torch.float)
edge_index = edge_index.permute(1, 0)
model = GCNConv( )
out = model(x, edge_index)

运行截图
在这里插入图片描述
MessagePassing的运行机制就是用行坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,然后用列坐标[1,0,2,1,0,1,2]进行 add 聚合信息,x_j其实就是根据行坐标得来的,行坐标里面的每一个元素其实就是一个节点标号,它告诉我们当前聚合信息时,每一个节点的信息应该是怎么样,在这里我没有转换节点feature,直接就是初始feature进行聚合,然后列坐标的元素进行聚合,如:列坐标中0节点与行节点对应的元素为1,0,所以在x_j对应位置找到元素6,-1然后相加得5,同理,1节点为-1+1+6 =6,2节点为6+1=7;
需要注意的是def message(self, x_j)中x_j的参数名字不能随便改变,不然会出错;其实x_j可以变为x_i,x_i代表以列坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,但仍以列坐标[0,1,1,2,0,1,2]汇聚信息;最后得到的结果如下:
在这里插入图片描述
也可以def message(self, x_j,x_i) ,其中x_j,x_i同时返回,可以根据具体应用进行灵活操作;

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值