1. GNN基本理论
2. 结合代码了解PyG的消息传递机制
class MyGCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # 'add' Aggregation
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# x has shape [N, out_channels]
# Step 4-5: Start propagating messages.
out = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
# out has shape [N, out_channels]
return out
def message(self, x_j, edge_index, size):
# x_j has shape [E, out_channels]
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, size[0], dtype=x_j.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
return aggr_out
if __name__ == '__main__':
dataset = DataLoader('texas')
data = dataset[0]
layer = MyGCNConv(1703, 100)
result = layer(data.x, data.edge_index)
首先,图信号和边集在构造时被传入之后,先运行forward函数,在边集中增加自环,然后经历一个线性变换(初始的特征变换)。然后就到了最关键的self.propagate函数,在self.propagate函数中,会自动调用self.message函数、self.aggregate函数、self.update函数,下面详细讲一下这三个函数的机制:
2.1 self.message: 邻居信息的变换
首先提一点,message函数可以接收任何从propagate函数传入的参数。
理解message函数要理解里面的和到底是什么,和都是shape为的向量,(代表边的数量)。所以实际上它是把edge_index中的两个的向量拿出来,对应第一个向量,对应第二个向量。就是第一个向量中所有边的信号,就是第二个向量中所有边的信号。
举个例子:假设有一个3结点的图,图信号矩阵和边集(添加自环后)为
,
本卷积层需要完成一个的变换,即,假设经历了初始线性变换之后图信号变换为
,那么有:
,其分别对应着0 0 1 2 0 1 2号结点(edge_index中的第1行),即所有头结点的图信号
,其分别对应着1 2 0 0 0 1 2号结点(edge_index中的第2行),即所有尾结点的图信号。
前面提到,message函数对应着邻居信息的变换,这里假设对邻居信息的变换具体指对邻居的信号乘以2,则直接返回2*x_j即可,即
def message(self, x_j):
return 2 * x_j
以上,message就完成了自己的任务,即对邻居信息的变换
可以理解为,在message函数中就是对edge_index做一个scatter,可以规避掉矩阵相乘的高复杂度计算。
2.2 self.aggregate: 邻居信息的聚合
聚合方式可以在构造函数中通过aggr参数选择(如add/mean/max),也可以自己实现,假设选择'add'聚合函数,即加和。下面看一下在aggregate函数内部,消息传递机制是如何实现的。
上述message函数的返回结果传递到了aggregate函数中,变换后的邻居结点信号为:
假设当前节点是1号结点,aggregate函数会根据edge_index找到其所有邻居节点(即0号和1号结点),然后对他们的信号进行一个加和操作。
,找到1号结点所有的邻居结点,即x_j的第3行和第6行,然后加和,结果为,对其余所有结点进行相同的操作,就完成了aggregate函数的任务:邻居信息的聚合。结果为:
2.3 self.update: 自己的信息与聚合后的邻居信息的变换
没写完