torch_geometric.nn.MessagePassing使用

torch_geometric.nn.MessagePassing使用

示例

torch_geometric.nn中有多种MessagePassing类可以使用。这些类的共同点是可以从图中接收消息并在节点之间进行传递。下面是使用MessagePassing类的一般步骤:

  1. 导入必要的模块
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
  1. 定义MessagePassing类
class GCN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization term.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        
        # Step 6: Return new node embeddings.
        return aggr_out
  1. 创建实例并使用
in_channels, out_channels = 16, 32
model = GCN(in_channels, out_channels)

x = torch.randn((10, in_channels))  # Node feature matrix
edge_index = torch.tensor([
    [0, 1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8],
    [1, 0, 2, 1, 4, 5, 3, 3, 7, 8, 6, 9]
])  # Edge indices

out = model(x, edge_index) # out.shape:[10, 32]

上述例子中,我们创建了一个简单的GCN模型,可以应用于输入特征为16个的图。在forward方法中,我们首先将self-loops加入到图中,然后将特征传递到线性层中。接下来,我们计算了边权重的标准化常数,然后使用propagate方法将消息从源节点传递到目标节点。在message方法中,我们对接收到的特征进行了标准化,然后在update方法中将聚合结果作为新的节点嵌入返回。最后,我们使用创建的GCN实例对输入数据进行了轻松的推断。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值