图注意力神经网络的pytorch代码解析
1.图注意力神经网络的原理简介
图注意网络的原理介绍有很多,可以参考知乎文章:向往的GAT(图注意力模型)。作者是清华大学的一个博士,他写的图卷积原理非常透彻,这里对于图注意力的描述也很好。
为了让后面的代码介绍更清楚,本文再重述一下注意力公式和多头注意力的原理。
1.1 注意力机制的公式
(1)计算注意力系数(attention coefficient)
对于顶点 i i i ,逐个计算它的邻居们( j ∈ N i j\in{N_i} j∈Ni)和它自己之间的相似系数:
e i j = a ( [ W h i ∣ ∣ W h j ] ) , j ∈ N i e_{ij}=a([Wh_i||Wh_j]),j\in{N_i} eij=a([Whi∣∣Whj]),j∈Ni
h i h_i hi与 h j h_j hj分别为中心节点及其邻居节点的特征。 W W W的作用在于对特征进行映射,提高特征的表达能力, [ ∗ ∣ ∣ ∗ ] [*||*] [∗∣∣∗]表示拼接,将映射之后的特征进行组合,并通过 a ( ∗ ) a(*) a(∗)映射成一个实数,作者通过单层前馈神经网络实现。然后通过类似于softmax的方法求解注意力系数:
α i j = e x p ( L e a k y R e L U ( e i j ) ) ∑ k ∈ N i e x p ( L e a k y R e L U ( e i k ) ) \alpha_{ij}=\frac{exp(LeakyReLU(e_{ij}))}{\sum_{k\in{N_i}}exp(LeakyReLU(e_{ik}))} αij=∑k∈Niexp(LeakyReLU(eik))exp(LeakyReLU(eij))
(2)特征的聚合
将计算好的注意力系数作为融合权重,对邻居节点的特征进行聚合:
h i ′ ( K ) = δ ( ∑ j ∈ N i α i j W h j ) h^{'}_{i}(K) = \delta(\sum_{j\in{N_i}}\alpha_{ij}Wh_j) hi′(K)=δ(j∈