文章来源于机器之心DGL专栏,作者:张昊、李牧非、王敏捷、张峥。
图卷积网络 Graph Convolutional Network (GCN) 告诉我们将局部的图结构和节点特征结合可以在节点分类任务中获得不错的表现。美中不足的是 GCN 结合邻近节点特征的方式和图的结构依依相关,这局限了训练所得模型在其他图结构上的泛化能力。
Graph Attention Network (GAT) 提出了用注意力机制对邻近节点特征加权求和。邻近节点特征的权重完全取决于节点特征,独立于图结构。
在这个教程里我们将:
- 解释什么是 Graph Attention Network
- 演示用 DGL 实现这一模型
- 深入理解学习所得的注意力权重
- 初探归纳学习 (inductive learning)
难度:★★★★✩(需要对图神经网络训练和 Pytorch 有基本了解)
在 GCN 里引入注意力机制
GAT 和 GCN 的核心区别在于如何收集并累和距离为 1 的邻居节点的特征表示。
在 GCN 里,一次图卷积操作包含对邻节点特征的标准化求和:
其中 N(i) 是对节点 i 距离为 1 邻节点的集合。我们通常会加一条连接节点 i 和它自身的边使得 i 本身也被包括在 N(i) 里。
是一个基于图结构的标准化常数;σ是一个激活函数(GCN 使用了 ReLU);W^((l)) 是节点特征转换的权重矩阵,被所有节点共享。由于 c_ij 和图的机构相关,使得在一张图上学习到的 GCN 模型比较难直接应用到另一张图上。解决这一问题的方法有很多,比如 GraphSAGE 提出了一种采用相同节点特征更新规则的模型,唯一的区别是他们将 c_ij 设为了|N(i)|。
图注意力模型 GAT 用注意力机制替代了图卷积中固定的标准化操作。以下图和公式定义了如何对第 l 层节点特征做更新得到第 l+1 层节点特征:
对于上述公式的一些解释:
- 公式(1)对 l 层节点嵌入
做了线性变换,W^((l)) 是该变换可训练的参数
- 公式(2)计算了成对节点间的原始注意力分数。它首先拼接了两个节点的 z 嵌入,注意 || 在这里表示拼接;随后对拼接好的嵌入以及一个可学习的权重向量 做点积;最后应用了一个 LeakyReLU 激活函数。这一形式的注意力机制通常被称为加性注意力,区别于 Transformer 里的点积注意力。
- 公式(3)对于一个节点所有入边得到的原始注意力分数应用了一个 softmax 操作,得到了注意力权重。
- 公式(4)形似 GCN 的节点特征更新规则,对所有邻节点的特征做了基于注意力的加权求和。
出于简洁的考量,在本教程中,我们选择省略了一些论文中的细节,如 dropout, skip connection 等等。感兴趣的读者们欢迎参阅文末链接的模型完整实现。
本质上,GAT 只是将原本的标准化常数替换为使用注意力权重的邻居节点特征聚合函数。
GAT 的 DGL 实现
以下代码给读者提供了在 DGL 里实现一个 GAT 层的总体印象。别担心,我们会将以下代码拆分成三块,并逐块讲解每块代码是如何实现上面的一条公式。
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
def __init__(se