深入浅出GAT–Graph Attention Networks(图注意力模型)

深入浅出GAT–Graph Attention Networks(图注意力模型)

1 GAT的诞生

由深度学习三巨头之一的Yoshua Bengio组提出了Graph Attention Networks(下述简称为GAT)去解决目前GCN存在的问题。

2 关于图

为了更好的引入,我们先来看看图的基础知识~~~
哒哒哒哒哒哒哒~~~~~~~~~~~~~~~~~~~~~~~~~

2.1 图的重要的“两特征”

提及graph,通常是包含着顶点和边的关系。
在这里插入图片描述

  • 第一种特征:对于任意一个顶点 i,它在图上邻居 Ni,构成第一种特征,即图的结构关系。

  • 第二种特征:除了图的结构之外,每个顶点还有自己的特征 hi(通常是一个高维向量)。它可以使社交网络中每个用户的个体属性;可以是生物网络中,每个蛋白质的性质;还可以使交通路网中,每个交叉口的车流量。

2.2 GCN的局限性

GCN是处理transductive任务的一把利器(transductive任务是指:训练阶段与测试阶段都基于同样的图结构),然而GCN有两大局限性是经常被诟病的:

(a)无法完成inductive任务,即处理动态图问题。inductive任务是指:训练阶段与测试阶段需要处理的graph不同。通常是训练阶段只是在子图(subgraph)上进行,测试阶段需要处理未知的顶点。(unseen node)

(b)处理有向图的瓶颈,不容易实现分配不同的学习权重给不同的neighbor。这一点在前面的文章中已经讲过了,不再赘述,如有需要可以参考下面的链接。
解读三种经典GCN中的Parameter Sharing

3 GAT

重点在获取其余节点对本节点的影响上。GAT本质上有两种运算方式,即Mask graph attention or global graph attention。

3.1 Global graph attention

**顾名思义,就是每一个顶点i都对于图上任意顶点都进行attention运算。**可以理解为图1的蓝色顶点对于其余全部顶点进行一遍运算。

优点:完全不依赖于图的结构,对于inductive任务无压力

缺点:(1)丢掉了图结构的这个特征,无异于自废武功,效果可能会很差(2)运算面临着高昂的成本

3.2 Mask graph attention

注意力机制的运算只在邻居顶点上进行,也就是说图1的蓝色顶点只计算和橙色顶点的注意力系数。目前常用方式。

4 输入浅出GAT

叮叮叮叮叮~~~~~~~~
重点来啦!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
和所有的attention mechanism一样,GAT的计算也分为两步走:

4.1 计算注意力系数(attention coefficient)

  • 对于顶点i,逐个计算它的邻居们j和它自己之间的相似系数
    在这里插入图片描述
    解读一下这个公式:首先一个共享参数 w 的线性映射对于顶点的特征进行了增维,当然这是一种常见的特征增强(feature augment)方法;|| 对于顶点i,j的变换后的特征进行了拼接(concatenate);最后 a( )把拼接后的高维特征映射到一个实数上。
  • 有了相关系数,离注意力系数就差归一化了!其实就是用个softmax:
    在这里插入图片描述

4.2 加权求和(aggregate)

现在已经成功一大半了哦~~~~~~~~~
第二步很简单,根据计算好的注意力系数,把特征加权求和(aggregate)一下。

  • GAT输出的对于每个顶点 [公式] 的新特征(融合了邻域信息)
    在这里插入图片描述
  • 式(3)看着还有点单薄,俗话说一个篱笆三个桩,attention得靠multi-head帮!来进化增强一下
    在这里插入图片描述

5 深入理解强化GAT

5.1 与GCN的联系与区别

本质上而言:GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算),利用graph上的local stationary学习新的顶点特征表达。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为顶点特征之间的相关性被更好地融入到模型中。

5.2 为什么GAT适用于有向图?

最根本的原因是GAT的运算方式是逐顶点的运算(node-wise),这一点可从公式(1)—公式(3)中很明显地看出。每一次运算都需要循环遍历图上的所有顶点来完成。逐顶点运算意味着,摆脱了拉普利矩阵的束缚,使得有向图问题迎刃而解。

5.3 为什么GAT适用于inductive任务?

GAT中重要的学习参数是 W 与 a( ) ,因为上述的逐顶点运算方式,这两个参数仅与1.1节阐述的顶点特征相关,与图的结构毫无关系。所以测试任务中改变图的结构,对于GAT影响并不大。
与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。学习的参数很大程度与图结构相关,这使得GCN在inductive任务上遇到困境。

5.4 GAT在AI医药的应用?

论文链接:Bi-Level Graph Neural Networks for Drug-Drug Interaction Prediction
以后看到再补充~

  • 21
    点赞
  • 131
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
以下是使用PyTorch实现GAT的代码示例: ``` python import torch import torch.nn as nn import torch.nn.functional as F class GATLayer(nn.Module): def __init__(self, in_dim, out_dim): super(GATLayer, self).__init__() self.in_dim = in_dim self.out_dim = out_dim self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim))) self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1))) nn.init.xavier_uniform_(self.W.data, gain=1.414) nn.init.xavier_uniform_(self.a.data, gain=1.414) def forward(self, h, adj): Wh = torch.mm(h, self.W) a_input = self._prepare_attentional_mechanism_input(Wh) e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(2)) zero_vec = -9e15*torch.ones_like(e) attention = torch.where(adj > 0, e, zero_vec) attention = F.softmax(attention, dim=1) h_prime = torch.matmul(attention, Wh) return h_prime def _prepare_attentional_mechanism_input(self, Wh): N = Wh.size()[0] Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0) Wh_repeated_alternating = Wh.repeat(N, 1) all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1) return all_combinations_matrix.view(N, N, 2*self.out_dim) class GAT(nn.Module): def __init__(self, n_feat, n_hid, n_class, dropout, alpha, n_heads): super(GAT, self).__init__() self.dropout = dropout self.attentions = [GATLayer(n_feat, n_hid) for _ in range(n_heads)] for i, attention in enumerate(self.attentions): self.add_module('attention_{}'.format(i), attention) self.out_att = GATLayer(n_hid*n_heads, n_class) self.alpha = alpha def forward(self, x, adj): x = F.dropout(x, self.dropout, training=self.training) x = torch.cat([att(x, adj) for att in self.attentions], dim=1) x = F.dropout(x, self.dropout, training=self.training) x = F.elu(self.out_att(x, adj)) return F.log_softmax(x, dim=1) ``` 在此示例中,我们实现了一个包含多头注意力机制的GAT模型。其中,GATLayer是GAT的核心组件,每个GATLayer都包含一个注意力头。在GAT模型中,我们将多个注意力头的输出连接在一起,再通过一个输出层进行分类。在forward函数中,我们首先对输入进行dropout,然后通过多个GATLayer进行特征提取,最后通过输出层进行分类并使用log_softmax进行预测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值