PyTorch Geometric自定义模型

本文详细介绍了如何在PyTorch Geometric(PyG)中自定义图神经网络(GNN)模型,包括理解MessagePassing机制、实现propagate(), message()和update()函数,以及通过实例解释了不同GNN层如GAT和GCN的实现细节。通过学习,读者将能够更好地掌握PyG库并创建自己的GNN模型。" 107626836,8765248,SQL Server:sp_MSforeachtable 存储过程详解与用法,"['数据库', 'SQL Server', '存储过程']
摘要由CSDN通过智能技术生成

这块的内容介绍相当的少,官网给出的doc也描述的没有非常清晰,以至于我刚开始学的时候非常云里雾里。后来慢慢明白了之后,才慢慢看懂官网的doc。
首先给出官网的doc:
在这里插入图片描述
这一块我们要注意几个点:方块代表的是聚合邻居信息的操作,比方说sum,mean,max,γ代表的是结合上一跳邻居信息和节点信息的更新操作,φ代表的是接受邻居信息的操作。
在这里插入图片描述了解了这些知识,再看PyG的注解就容易了,我们需要实现的是这几个函数,就可以自定义GNN的一层:
在这里插入图片描述
总结下来就这几点:

  1. MessagePassing定义邻居信息的聚合器,即图2的Aggregation
  2. MessagePassing.propagate()是开始传播信息的函数,下面的message()函数和update()函数中需要的参数都要被包含进这个函数的参数中。
  3. message()函数是定义 x j x_j xj如何传播到节点 i i i的函数,即图2的φ函数。比方说,在GAT中, x j x_j xj传播到节点 i i i需要乘上注意力,那么就应该在这个函数里定义。值得注意的是,所有在propagate()中定义过的参数都可以在message()中直接使用,而且还可以在参数后面加上_j指定节点。比方说,我们在在propagate()中定义了参数x表示所有节点的features,x大小为[N,features],那么我们就可以使用x_ix_j表示节点j的特征,大小是[E,features],为什么大小是[E,features]呢,我理解这里使用x_j代表着所有边的源节点的特征,x_i代表着所有边的目标节点的特征。
  4. 更新节点embedding的函数,对应于图二的γ,和message()函数一样可以直接使用propagate()中的参数,默认参数是aggr_out,这个参数代表着经过上一步的message()之后所有节点获得的邻居节点的表示,大小是[N,in_channels]

MessagePassing—定义邻居信息聚合器
propagate():传播信息,指定要传播的信息和边的集合
message():接收信息,返回 x j x_j xj传到 x i x_i xi之后变成了啥
update():更新节点 i i i表示。

在这里插入图片描述
在这里插入图片描述


来看一个小栗子:
在这里插入图片描述
我们对照图2可以看出,此时的γ是没有的,也就是说不需要实现update()函数,Aggregation是max,也就是说我们的MessagePassing需要定义为max,propagate()正常传播就行,message()里面比较复杂,在接收到

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

canaryW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值