目录
② propagate(self, edge_index, size, 其他参数)
本教程中,基于消息传递模型(MessPassing基类)的GNN在coding时主要需要实现 3个部分,foward(...)函数、message(...)函数和aggregate(...)函数。浅谈一下对函数执行过程的初步理解,以下内容均未仔细查阅文档/debug验证。
① forward(...)函数:
该函数总揽全局,主要实现:数据的预处理、调用propagate(...)函数,数据的后处理,三个过程。
其中,调用propagate(...)函数时传入的参数可分四部分:self、edge_index、size、其他参数args(例如wjunjie=(2,3),为什么要写成这种序对的形式,可能是规定吧。第一个数“2”在propagate(...)、message(...)等之后的函数中被记作wjunjie_j,第二个数记作wjunjie_i,i和j分别表示节点作为“主节点”和“邻居节点”的值,一般是一样的,例如节点表征矩阵X=(X, X)),前2个必须要有,size没有的话可自动计算,其他参数随意。
一般,节点表征矩阵X当作其他参数来处理,但X不能随意,这个一般必须要有。
② propagate(self, edge_index, size, 其他参数)
在propagate函数中,程序会调用message(...)函数和aggregate(...)函数,以完成“邻居节点”表征变换操作和表征聚合操作。
其中传入message和aggregate的参数主要包括三类:self、foward(...)函数传入propagate(...)函数的自定义参数、根据参数edg_index和X计算出的一些量(都有哪些量?猜测如下源码截图:)。
③ message(...)函数
源码中的message函数只有self和x_j两个参数,如下第一个message函数截图所示。该函数表示“邻居节点们”的表征矩阵(即x_j矩阵)应该进行哪些变化。
除了这两个参数,我们在重写message函数时也可以添加一些其他参数,但原则是这些参数确实被传入了message函数,都有哪些参数呢?即propagage(...)函数中所写的三类参数:self、自定义参数以及propagage计算出的参数,如下第二个message函数截图所示,x和wjunije是自定义的,index、ptr等是propagage函数计算得到的。
④ aggregate(...)函数:
对前面变化的信息进行聚合,源码中定义如下图(源码定义的参数就挺全了),inputs是message函数的输出(当然源码中inputs可能是对message的输出进行了一些小处理后的结果,先不管),然后其他参数大都是“propagage函数中计算得到的”,如index(该量可以看作等于edge_index,实际一不一样,本人并未验证过,仅看到源码中有形似index=edge_index的几句代码)。
总的来说,重写aggregate函数时需要定义的参数也都必须是propagage函数中确实传入到了aggregate函数的。一般必须要有的参数有self、inputs(节点表征矩阵)以及index(保存有邻接关系)3个量。
下图源码的aggregate函数中最后一行有reduce=self.aggr,所以aggregate(...)函数一般也不需要重写,在GNNConv的__init__(...)函数中初始化一些self.aggr就行了(没试过,需要的时候再试吧)。
End...