供学习参考使用,()中内容表示操作后的张量shape,请与源码对照观看
- 节点特征经过全连接层进行线性变换(num_n, feature_n)->(num_n, heads, out_dim)
- 将变换后的节点特征分源节点和宿节点与分别源、宿权重系数att(1, heads, out_dim)相乘,得到((num_n, heads), (num_n, heads)),表示节点作为源和宿节点时的重要性系数
- 为图添加自环边
- 将边特征经过全连接层进行线性变换(num_edge, heads, out_dim),并与边权重系数相乘,得到边的重要性系数(num_edge, heads)
- 根据edge_index(2, edge_num)中的源和宿节点索引,对重要性系数索引,得到按edge_index顺序排列的源和目标重要性系数(num_edge, heads)
- 将边、源节点、宿节点重要性系数相加(相当于完成拼接),并经过负半平面斜率为0.2的LeakyRelu激活,得到从源到宿节点传递信息的注意力系数
- 将注意力系数按相同宿节点进行softmax归一化处理(num_edge, heads)
- 根据edge_index中的源节点索引,对线性变换后的节点特征索引,得到按edge_index顺序排列的源节点特征(num_edge, heads, out_dim)
- 归一化后的注意力系数与源节点特征对应相乘(num_edge, heads, out_dim)—message()
- 创建(num_n, heads, out_dim)的零张量,将传递的信息,按照宿节点索引(num_edge, )->(num_edge, heads, out_dim),使信息(num_edge, heads, out_dim)中的每个节点的每个head的每一维信息,按照对应位置的索引,加到零张量被索引节点对应head的对应元素上,完成信息汇聚---aggregate()
- 更新节点属性,得到经图注意力层处理后的节点特征(num_n, out_dim)[mean]/(num_n, heads*out_dim)[concat]---update()