图注意力神经网络GAT 浅析,单层GAT layer 理解


class GATLayer(nn.Module):
    def __init__(self,g,in_dim,out_dim):
        super(GATLayer,self).__init__()
        self.g=g
        self.fc=nn.Linear(in_dim,out_dim,bias=False)
        self.attn_fc=nn.Linear(2*out_dim, 1, bias=False)  # alpha层,映射高维到一个单个的数。

    def edge_attention(self,edges):

        z2=torch.cat([edges.src['z'],edges.dst['z']],dim=1)
        a=self.attn_fc(z2)
        return {'e':F.leaky_relu(a)}

    def message_func(self,edges): 
        return  {'z':edges.src['z'],'e':edges.data['e']}
    
    def reduce_func(self,nodes):
        alpha = F.softmax(nodes.mailbox['e'],dim=1)
        h=torch.sum(alpha * nodes.mailbox['z'],dim=1)
        return {'h': h}
    
    def forward(self,h):
        z=self.fc(h)
        self.g.ndata['z']=z
        self.g.apply_edges(self.edge_attention)  # 调用参数中的函数来更新边的features,即用src点节和dst节点的z特征拼接后,求attention值 
        self.g.update_all(self.message_func, self.reduce_func)  #通过所有的边来发送消息并更新所有节点
        return self.g.ndata.pop('h') # 移除ndata的h数据的最后一个值并返回

图注意力神经网络中的注意力: 注意力其实就是加权求和中的权重。 在图结构中,注意力是指一个节点的邻居节点对它的重要性(权重大小)。在图中,一个节点的邻居对它的作用并不是平等的,而且也不是对称的。比如 节点2->节点3的权重大,不代表反过来  节点3->节点2的权重也大(有可能节点3对节点2并不重要)。举个例子: 酵母对于面粉的作用很大,放一点进去就能发面,但是把少量面粉放酵母并没有啥作用,即面粉对酵母并没有太多的作用(权重小)。

 

 

 

完整的GAT分为三层: GAT layer (单层GAT), MultiHeadGATLayer, GA模型(多层MultiheadGATLayer)

如果拿CNN作为例子,对照关系就是:

GATLayer = 单个卷积核 

MultiheadGATLayer = 多通道卷积

GAT模型就是CNN整个模型了,由多层的多通道卷积组成。

关于比较好的 GAT描述,请参见 : 向往的GAT(图注意力模型) - 知乎 (zhihu.com)

 

GAT的核心在于计算注意力系数,看这几行

 

   def edge_attention(self,edges):

        z2=torch.cat([edges.src['z'],edges.dst['z']],dim=1)   #对应下面公式中的 || 运算(拼接)。
        a=self.attn_fc(z2)      #单层前馈神经网络,对应的就是alpha操作
        return {'e':F.leaky_relu(a)}

 

公式如下:

 

alpha是一个单层神经前馈神经网络。最后 alpha 把拼接后的高维特征映射到一个实数eij上。

其中的W就是self.attn_fc函数中的权重W了(由全连接网络训练得到)。

公式2: 

对应: 

    return {'e':F.leaky_relu(a)}

非线性函数,这里就是leaky_relu了(原始论文中使用的,据说是经验得出,并没有太多理论解释)。

到于softmax,看下面的代码:

   def reduce_func(self,nodes):
        alpha = F.softmax(nodes.mailbox['e'],dim=1)
        h=torch.sum(alpha * nodes.mailbox['z'],dim=1)
        return {'h': h}

连起来看函数  edge_attention  和 reduce_func, 就能把两个公式连起来了。

 

GAT中的A指的是attention, 主要就是注意力系统数,这是一个标量。核心就是算每条边的这个标量值 。

 

上面是粗浅理解 ,不对的地儿请大伙儿指出。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值