PyG的MessagePassing基类中self.propagate函数的消息传递机制

本文介绍了图神经网络(GNN)中的MyGCNConv类,重点讲解了其消息传递机制,包括添加自环、节点特征变换、message函数、aggregate函数的使用。通过实例展示了如何在PyG中结合代码理解这些关键步骤。
摘要由CSDN通过智能技术生成

1. GNN基本理论

参考文章:CS224w 图神经网络(Graph Neural Networks) - 知乎Hi,大家好,这里是居家隔离的糖葫芦喵喵~! 在之前的内容里我们讨论了图像和自然语言的机器学习方法以及简单的强化学习方法,今天开始我们要接触到机器学习的另一个有趣的领域——图机器学习。下面为大家带来斯坦…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/113862170

 2. 结合代码了解PyG的消息传递机制

class MyGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')    # 'add' Aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
        # x has shape [N, out_channels]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
        # out has shape [N, out_channels]

        return out

    def message(self, x_j, edge_index, size):
        # x_j has shape [E, out_channels]

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out


if __name__ == '__main__':
    dataset = DataLoader('texas')
    data = dataset[0]

    layer = MyGCNConv(1703, 100)
    result = layer(data.x, data.edge_index)

 首先,图信号data.x和边集data.edge\_index在构造时被传入之后,先运行forward函数,在边集中增加自环,然后经历一个线性变换(初始的特征变换)。然后就到了最关键的self.propagate函数,在self.propagate函数中,会自动调用self.message函数、self.aggregate函数、self.update函数,下面详细讲一下这三个函数的机制:

2.1 self.message: 邻居信息的变换 

首先提一点,message函数可以接收任何从propagate函数传入的参数。

理解message函数要理解里面的x\_ix\_j到底是什么,x\_ix\_j都是shape为E\times out\_channels的向量,(E代表边的数量)。所以实际上它是把edge_index中的两个1\times E的向量拿出来,x\_i对应第一个向量,x\_j对应第二个向量。x\_i就是第一个向量中所有边的信号,x\_j就是第二个向量中所有边的信号。

举个例子:假设有一个3结点的图,图信号矩阵和边集(添加自环后)为

x=\bigl(\begin{smallmatrix} 1 & 2 & 3 & 4 & 5\\ 6 & 7 & 8 & 9 & 10\\ 11 & 12 &13 & 14 &15 \end{smallmatrix}\bigr)edge\_index=\bigl(\begin{smallmatrix} 0 & 0 & 1 & 2 & 0 & 1 & 2\\ 1 & 2 & 0 & 0 & 0 & 1 & 2 \end{smallmatrix}\bigr)

本卷积层需要完成一个5\rightarrow 2的变换,即out\_channels=2,假设经历了初始线性变换之后图信号变换为

x=\bigl(\begin{smallmatrix} 1 & 2 \\ 3 & 4\\ 5 & 6\end{smallmatrix}\bigr),那么有:

x\_i=\bigl(\begin{smallmatrix} 1 & 2\\ 1 & 2\\ 3 & 4\\ 5 & 6\\ 1 & 2\\ 3 & 4\\ 5 & 6 \end{smallmatrix}\bigr),其分别对应着0 0 1 2 0 1 2号结点(edge_index中的第1行),即所有头结点的图信号

x\_j=\bigl(\begin{smallmatrix} 3 & 4\\ 5 & 6\\ 1 & 2\\ 1 & 2\\ 1 & 2\\ 3 & 4\\ 5 & 6 \end{smallmatrix}\bigr),其分别对应着1 2 0 0 0 1 2号结点(edge_index中的第2行),即所有尾结点的图信号。

前面提到,message函数对应着邻居信息的变换,这里假设对邻居信息的变换具体指对邻居的信号乘以2,则直接返回2*x_j即可,即

    def message(self, x_j):
        return 2 * x_j

以上,message就完成了自己的任务,即对邻居信息的变换

可以理解为,在message函数中就是对edge_index做一个scatter,可以规避掉矩阵相乘的高复杂度计算。 

2.2 self.aggregate: 邻居信息的聚合

聚合方式可以在构造函数中通过aggr参数选择(如add/mean/max),也可以自己实现,假设选择'add'聚合函数,即加和。下面看一下在aggregate函数内部,消息传递机制是如何实现的。

上述message函数的返回结果传递到了aggregate函数中,变换后的邻居结点信号为:

x\_j=2\times \bigl(\begin{smallmatrix} 3 & 4\\ 5 & 6\\ 1 & 2\\ 1 & 2\\ 1 & 2\\ 3 & 4\\ 5 & 6 \end{smallmatrix}\bigr)=\bigl(\begin{smallmatrix} 6 & 8\\ 10 & 12\\ 2 & 4\\ 2 & 4\\ 2 & 4\\ 6 & 8\\ 10 & 12 \end{smallmatrix}\bigr)

假设当前节点是1号结点,aggregate函数会根据edge_index找到其所有邻居节点(即0号和1号结点),然后对他们的信号进行一个加和操作。

edge\_index=\bigl(\begin{smallmatrix} 0 & 0 & 1 & 2 & 0 & 1 & 2\\ 1 & 2 & 0 & 0 & 0 & 1 & 2 \end{smallmatrix}\bigr),找到1号结点所有的邻居结点,即x_j的第3行和第6行,然后加和,结果为(2,4)+(6,8)=(8,12),对其余所有结点进行相同的操作,就完成了aggregate函数的任务:邻居信息的聚合。结果为:

aggr\_out=\bigl(\begin{smallmatrix} 18 & 24\\ 8 & 12\\ 12 & 16 \end{smallmatrix}\bigr)

2.3 self.update: 自己的信息与聚合后的邻居信息的变换

没写完

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值