pytorch_geometric:message passing neural networks ( 以GCNConv 为例 ) 博客传送门

官方文档

message passing networks

Torch geometric GCNConv 源码分析

补充说明

这里附上以上博客中提到的D,A

这里附上综合以上博客,对 GCNConv 注释后的代码

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels): # 构造的时候必须输入in,out
        super(GCNConv,self)._init_(aggr='add')
        self.lin = torch.nn.Linear(in_channels,out_channels)

    def forward(self, x, edge_index): # 调用的时候必须输入 x, edge_index
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        #### Steps 1-2 are typically computed before message passing takes place.
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index,_ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: Linearly transform node feature matrix. 压缩 node feature
        x = self.lin(x)

        #### Steps 3-5 can be easily processed using the torch_geometric.nn.MessagePassing base class.
        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)#得到x_j


    def message(self, x_j, edge_index, size):
        # x_j has shape [E, out_channels]

        # Step 3: Normalize node features.
        row,col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype) # [N, ] dtype是数据类型
        deg_inv_sqrt = deg.pow(-0.5) # [N(-0.5), ]
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # (N(-0.5),N(-0.5),E)

        # step 4:Aggregation.
        return norm.view(-1,1) * x_j # [N, E]*[E, out_channels]=[N, out_channels]


    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]

        # Step 5: Return new node embeddings.
        return aggr_out

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

梦dancing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值