Torch.geometric==2.3.1中GATConv代码流程解析

供学习参考使用,()中内容表示操作后的张量shape,请与源码对照观看

  • 节点特征经过全连接层进行线性变换(num_n, feature_n)->(num_n, heads, out_dim)
  • 将变换后的节点特征分源节点和宿节点与分别源宿权重系数att(1, heads, out_dim)相乘,得到((num_n, heads),  (num_n, heads)),表示节点作为源和宿节点时的重要性系数
  • 为图添加自环边
  • 将边特征经过全连接层进行线性变换(num_edge, heads, out_dim),并与边权重系数相乘,得到边的重要性系数(num_edge, heads)
  • 根据edge_index(2, edge_num)中的源和宿节点索引,对重要性系数索引,得到按edge_index顺序排列的源和目标重要性系数(num_edge, heads)
  • 将边、源节点、宿节点重要性系数相加(相当于完成拼接),并经过负半平面斜率为0.2LeakyRelu激活,得到从源到宿节点传递信息的注意力系数
  • 注意力系数按相同宿节点进行softmax归一化处理(num_edge, heads)
  • 根据edge_index中的源节点索引,对线性变换后的节点特征索引,得到按edge_index顺序排列的源节点特征(num_edge, heads, out_dim)
  • 归一化后的注意力系数与源节点特征对应相乘(num_edge, heads, out_dim)—message()
  • 创建(num_n, heads, out_dim)的零张量,将传递的信息,按照宿节点索引(num_edge, )->(num_edge, heads, out_dim),使信息(num_edge, heads, out_dim)中的每个节点的每个head的每一维信息,按照对应位置的索引,加到零张量被索引节点对应head的对应元素上,完成信息汇聚---aggregate()
  • 更新节点属性,得到经图注意力层处理后的节点特征(num_n, out_dim)[mean]/(num_n, heads*out_dim)[concat]---update()
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值