GAT模型数学原理与代码详解(pytorch)

 这是一份用于图注意力网络GAT模型理解的入门教程,采用论文与代码结合的方式阐述GAT注意力机制的实现过程。本文关于数学原理部分不一定完全严谨,如有错误请在评论区指出。 

该模型来自论文:https://arxiv.org/abs/1710.10903icon-default.png?t=O83Ahttps://arxiv.org/abs/1710.10903本文着重分析pytorch框架的模型实现,代码来自开源网络:GitHub - Diego999/pyGAT: Pytorch implementation of the Graph Attention Network model by Veličković et. al (2017, https://arxiv.org/abs/1710.10903)


1 频域网络GCN与空域网络GAT

1.1 频域网络GCN

        在图神经网络(GNN)领域中发展出了许多不同的节点信息聚合方式,其中空域GNN与频域GNN是两个十分最重要的发展方向,在之前的文章中我们已经从0开始大致证明了空域GNN的典型算法之一GCN的实现方式:

GCN的基础入门及数学原理_gcn学习-CSDN博客

        在这里稍作复习,我们知道GCN节点信息的聚合方式是取决于图的拉普拉斯矩阵。为了应对空域中不规则图结构无法使用固定卷积核的问题,GCN采用“傅里叶变换”的思想,将图结构变换到频域,完成卷积后再采用逆变换将其变回空间域中。GCN本质上是在频域上对图信号进行滤波。通过拉普拉斯矩阵的特征分解,我们可以将图卷积视为一种频域滤波器。

1.2 空域网络GAT

        GAT(Graph Attention Network)是基于空域图卷积的思想,它通过学习每对节点间的注意力权重,来决定每个节点与邻居节点交互的强度。GAT的核心是通过自注意力机制来加权节点间的信息传播,而不是依赖固定的邻接矩阵。

其步骤如下:

  • 节点特征映射
  • 计算每个节点之前的注意力分数(权重)
  • 邻居信息加权求和
  • 多头注意力

其实现代码实现如下:

class GraphAttentionLayer(nn.Module):

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, x, adj):
        h = torch.bmm(x, self.W)
        N = x.size()[0]
        a_input = torch.cat([h.repeat(1, N).view(N * N, self.out_features), h.repeat(N, 1)], dim=1).view(N, N, 2 * self.out_features)
        attention = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        zero_vec = -9e15 * torch.ones_like(attention)
        attention = torch.where(adj > 0, attention, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, h)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, n_final_out, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, n_final_out, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return x

2 GAT注意力机制的实现

        如果您了解过Transformer模型,那么其中的自注意力机制一定会让您印象深刻,GAT中的注意力机制与Transformer中的注意力机制有异曲同工之妙。如果您没有了解过Transformer也无妨,本文将以最简单易懂的语言来阐述注意力机制的奇妙实现方式。

       GAT中的注意力机制实则是在图中对每个节点之间连接的边赋予权重。如下图示例,GAT中注意力机制即是通过学习得到节点之间的“注意力评分”,可以看出,联系越密切的点其评分越高。

        在论文中注意力评分 e_{ij} 的计算公式如下:

e_{ij}=a\left(\left[Wh_i||Wh_j\right]\right),j\in\mathcal{N}_i

        公式描述了注意力评分的计算方法,其中 W 是一个可学习的矩阵,形状为(F_{in},F_{out}),其作用是将输入节点的特征映射到高维;h_i,h_j 是节点 i,j 的输入特征向量,形状为 (1,F_{in})a 是一个形状为 (1,2\times F_{out}) 的矩阵,其目的是将最终结果映射为一个特定的值,其具体计算流程如下:

        上图中,假设我们初始输入的每个节点都有两个特征,F_{in}=2, F_{out}=5 

        节点 i 和节点 j 的输入特征向量经过同一个可学习的参数矩阵 W,其原本的2维特征被映射到5维,之后文中对这两个升维后的特征进行拼接(concatenate),再经过另一个可学习的注意力参数矩阵 a 映射为一个特征值,这样就得到了节点 i 对于节点 j 的注意力分数。(注意,这个注意力是单向的,如果是节点 j 对于节点 i 则要重新换顺序计算一次,下文会给出更细致的解释)


        上图只是计算了一个方向的注意力分数,实际运用中我们一般是期望得到一个注意力矩阵,矩阵的形状一般为 (N\times N) ,其中为 N 节点个数,具体实现流程如下图:

        从最简单的例子入手,假如有一个两个节点的图,每个节点有一个输入特征,矩阵W将特征升高到3维度,即 N=2, F_{in}=1, F_{out}=3,我们期望得到的是一个 2\times2 的注意力矩阵,将该示例带入代码中,得到的流程如下图所示:

        下面以代码角度阐述计算过程:

h = torch.bmm(x, self.W)

h=xW

        得到结果:

h=\begin{bmatrix}[1.0,2.0,3.0],\\ [4.0,5.0,6.0]\end{bmatrix}

        可以看到经过线性矩阵 W 的作用后,两个节点的特征从1维扩展到3维。

h.repeat(1, N)

        执行 h.repeat(1, N),意思是将 h 沿着第二维(特征维度)重复 N 次。因为 N=2,所以会得到以下的张量,形状为 【2,3】(2 是节点数,6 是特征维度):

h_{\mathrm{repeat1}}=\begin{bmatrix}[1.0,2.0,3.0,1.0,2.0,3.0],\\ [4.0,5.0,6.0,4.0,5.0,6.0]\end{bmatrix}

h.repeat(N, 1)

        接着,执行 h.repeat(N, 1),意思是将 h 沿着第一维(节点数维度)重复 N 次。因为 N=2,所以得到以下的张量,形状为 【4,3】(4 是节点数,3 是特征维度):

h_{\mathrm{repeat2}}=\begin{bmatrix}[1.0,2.0,3.0],\\ [4.0,5.0,6.0],\\ [1.0,2.0,3.0], \\ [4.0,5.0,6.0]\end{bmatrix}

torch.cat([h.repeat(1, N), h.repeat(N, 1)], dim=2)

        这个操作沿着第二维(特征维度)拼接两个重复的张量。

  •  第一个张量是 h.repeat1(1, N),它的形状是 【2, 6】。
  •  第二个张量是 h.repeat2(N, 1),它的形状是 【4, 3】。

        执行 torch.cat 后,它们沿着第二维拼接,结果如下:

a_{\mathrm{input}}=\begin{bmatrix}[1.0,2.0,3.0,1.0,2.0,3.0],\\ [1.0,2.0,3.0,4.0,5.0,6.0],\\ [4.0,5.0,6.0,1.0,2.0,3.0],\\ [4.0,5.0,6.0,4.0,5.0,6.0]\end{bmatrix}

        这个矩阵 a_{input} 的形状是 【2, 2, 6】,也就是每对节点的特征向量拼接成一个长度为 6 的向量。

torch.cat([h.repeat(1, N), h.repeat(N, 1)], dim=2).view(N, N, 2 * self.out_features)

        最后,通过 .view(N, N, 2 * C) 将张量重新调整为形状 【2, 2, 6】,即每一对节点(共 2x2=4 对)都被拼接了它们的特征,形状为 【N, N, 2 * C】,可表示为:

a_{\text{input}}=\begin{bmatrix}[[1.0,2.0,3.0,1.0,2.0,3.0],[1.0,2.0,3.0,4.0,5.0,6.0]],\\ [[4.0,5.0,6.0,1.0,2.0,3.0],[4.0,5.0,6.0,4.0,5.0,6.0]]\end{bmatrix}

torch.matmul(a_input, self.a).squeeze(2)

        采用形状为 (2\times F_{out},1)\rightarrow(6,1) 可学习的注意力参数矩阵 a,使用 torch.matmul 批量矩阵乘法得到最初的注意力矩阵 e_{\text{initial}}

a=\begin{bmatrix}1.0, 2.0, 3.0, 4.0, 5.0, 6.0\end{bmatrix}^T

e=\begin{bmatrix} 46&91\\ 64&109\end{bmatrix}

attention = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
attention = F.softmax(attention, dim=1)

        最终采用 leakyrelu 进行非线性操作,使用softmax归一化后得到最终结果:

\alpha_{ij}=\dfrac{exp(LeakyReLU(e_{ij}))}{\sum_{k\in\mathcal{N}_{i}}exp(LeakyReLU(e_{ik}))}

        接下来我们就可以采用计算好的注意力系数对节点进行特征更新了: 

\vec{h}_i^{\prime}=\sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}\mathbf{W}\vec{h}_j\right)

        其中 \vec{h}_i^{\prime} 是每个节点更新后的特征向量,\sigma(\cdot) 是激活函数。


        值得注意的是,对于Transformer的注意力机制来说,作用是求取全局注意力,例如一个包含10个单词的句子,Transfromer期望得到一个 10\times 10 填满的注意力矩阵。然而在GNN领域中我们更注重于节点相邻接的节点对其影响,而不是全局结果。

        对于上文的例子来说,我们实现的功能的计算出每个节点之间的注意力评分,然而这是全局注意力。以下图为例,Node5于Node4并没有直接相连,在这里计算这两点之间的注意力评分是多余的,为了将图结构加入到注意力机制中,作者使用了MASK方法:

zero_vec = -9e15 * torch.ones_like(attention)
attention = torch.where(adj > 0, attention, zero_vec)

        MASK方法的原理也十分简单,就是使用图结构的邻接矩阵于注意力句子的哈达玛乘积(Hadamard product)将不连接节点之间的注意力分数转化为无穷小,即可得到期望的结果。


3 多头注意力机制

        关于多头注意力机制的引入,作者在文中这样写到:

 简化来说就是为了增强模型的稳定性与准确性。文中公式如下:

        式中 K 代表了注意力头的个数,将 K 个头的注意力机制拼接,最终得到的特征向量的维度也变为原来的 K 倍,即 F'\rightarrow KF'

        特别的,如果需要在网络的最后(预测)层上执行多头注意力,那么采用平均值即可得到期望的结果:

\vec{h}_i^{\prime}=\sigma\left(\frac1K\sum_{k=1}^K\sum_{j\in\mathcal{N}_i}\alpha_{ij}^k\mathbf{W}^k\vec{h}_j\right)

        图中每个不同颜色的波浪线即代表不同的注意力头,共同作用更新节点信息。


GAT(Graph Attention Network)是一种基于图神经网络模型,用于处理图数据。PyTorch是一种深度学习框架,用于构建、训练和部署神经网络模型。下面是关于GAT代码PyTorch中的解释: 在PyTorch中实现GAT代码主要包括以下几个步骤: 1. 数据准备:首先,需要准备图数据的节点特征和边信息。节点特征可以是任意维度的向量,边信息可以是节点之间的连接关系。 2. 模型定义:接下来,需要定义GAT模型的网络结构。GAT模型主要由多个Graph Attention Layer组成,每个Attention Layer都有一个注意力权重计算机制,用于计算节点之间的注意力得分。在PyTorch中,可以使用torch.nn.Module类定义GAT模型,并在forward()方法中实现模型的前向传播计算。 3. 注意力计算:注意力机制是GAT模型的核心。在每个Attention Layer中,可以使用自定义函数或者使用PyTorch提供的函数,例如torch.nn.functional中的softmax()函数来计算节点之间的注意力得分。 4. 训练模型:定义好模型后,需要准备训练数据,并使用合适的优化器和损失函数对模型进行训练。在训练过程中,可以使用PyTorch提供的自动微分机制来计算梯度,并使用优化器来更新模型的参数。 5. 模型评估:训练完成后,可以使用测试数据对模型进行评估。可以计算模型的准确率、精确率、召回率等指标来评估模型的性能。 总结起来,GAT代码PyTorch中主要包括数据准备、模型定义、注意力计算、训练模型模型评估等步骤。通过使用PyTorch提供的函数和类,可以方便地实现GAT模型,并对图数据进行学习和预测。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值