最近论文中需要使用图卷积神经网络(GNN),看了一些关于GCN的代码,还有基于PyTorch Geometric Temporal的代码实现,在这里做一下记录。
GCN原始代码
关于GCN的原理在这里不进行过多阐述,其他文章里面解释的已经很详细了,这里就直接进入到代码的部分。GCN的公式如下:
其中为邻接矩阵;
为t时刻输入的节点的特征矩阵;
是近似的图卷积滤波器,其中
=
+
(
是N维的单位矩阵);
是度矩阵;
代表需要神经网络训练的权重矩阵;
是激活函数Relu。
根据公式逐步实现GCN的代码如下:
def get_gcn_fact(adj):
'''
Function to