深入理解图注意力网络代码(GAT)

目录

一、整体代码

二、解释 拼接操作

一)创建所有可能的配对

二)拼接以形成配对

三)示例

①假设 out_features 为 2,我们的序列 h 为

②h.repeat(1, N) 生成的矩阵将是

③这里插播一个pytorch中view的使用方法

④因此h.repeat(1, N).view(N*N, -1) 的输出结果是

⑤h.repeat(N, 1) 生成的矩阵将是

⑥torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).拼接后的矩阵是


一、整体代码

​
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 
    图注意力层
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features   # 节点表示向量的输入特征维度
        self.out_features = out_features   # 节点表示向量的输出特征维度
        self.dropout = dropout    # dropout参数
        self.alpha = alpha     # leakyrelu激活的参数
        self.concat = concat   # 
        
        # 定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  
        nn.init.xavier_uniform_(self.W.data, gain=1.414)  # xavier初始化
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)   # xavier初始化
        
        # 定义leakyrelu激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    
    def forward(self, inp, adj):
        """
        inp: input_fea [N, in_features]  in_features表示节点的输入特征向量元素个数
        adj: 图的邻接矩阵 维度[N, N] 非零即一,数据结构基本知识
        """
        h = torch.mm(inp, self.W)   # [N, out_features]
        N = h.size()[0]    # N 图的节点数
        
        a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)
        # [N, N, 2*out_features]
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        # [N, N, 1] => [N, N] 图注意力的相关系数(未归一化)
        
        zero_vec = -1e12 * torch.ones_like(e)    # 将没有连接的边置为负无穷
        attention = torch.where(adj>0, e, zero_vec)   # [N, N]
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留,
        # 否则需要mask并置为非常小的值,原因是softmax的时候这个最小值会不考虑。
        attention = F.softmax(attention, dim=1)    # softmax形状保持不变 [N, N],得到归一化的注意力权重!
        attention = F.dropout(attention, self.dropout, training=self.training)   # dropout,防止过拟合
        h_prime = torch.matmul(attention, h)  # [N, N].[N, out_features] => [N, out_features]
        # 得到由周围节点通过注意力权重进行更新的表示
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime 
    
    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

​

二、解释 拼接操作

也就是下面这个代码:

a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)

它构建了一个特定的数据结构,这个结构允许每个元素都能与序列中的每个其他元素配对。

一)创建所有可能的配对

h.repeat(1, N).view(N*N, -1)

这个操作重复了 h 中的每个元素 N 次,创建了一个包含所有可能的“行”配对的张量。例如,如果 h 是一个序列中的元素,这个操作就创建了一个包含这个元素与序列中每个元素(包括它自己)的配对的张量。

h.repeat(N, 1)

这个操作重复整个 h N 次,创建了一个包含所有可能的“列”配对的张量。

二)拼接以形成配对

通过 torch.cat,这两个重复的张量被沿着特定的维度拼接起来。这意味着对于 h 中的每个元素,你现在有了一个包含了它与序列中每个其他元素的配对的完整集合。

三)示例

①假设 out_features 为 2,我们的序列 h
h1​=[1,2]

h2​=[3,4]

h3​=[5,6]

②h.repeat(1, N) 生成的矩阵将是
[1,2],[1,2],[1,2] 

[3,4],[3,4],[3,4]

[5,6],[5,6],[5,6]
③这里插播一个pytorch中view的使用方法
#初始化一个tensor
import torch
a1 = torch.arange(0,16)
print(a1)

#输出为:tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)

print(a2)
print(a3)
print(a4)

输出为:

tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
④因此h.repeat(1, N).view(N*N, -1) 的输出结果是
 
[1,2]
[1,2]
[1,2]
[3,4]
[3,4]
[3,4]
[5,6]
[5,6]
[5,6]
h.repeat(N, 1) 生成的矩阵将是
[1,2] 
[3,4]
[5,6]
[1,2] 
[3,4]
[5,6]
[1,2] 
[3,4]
[5,6]
⑥torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).拼接后的矩阵是
[1,2], [1,2]
[1,2], [3,4]
[1,2], [5,6]
[3,4], [1,2]
[3,4], [3,4]
[3,4], [5,6]
[5,6], [1,2]
[5,6], [3,4]
[5,6], [5,6]

这个矩阵包含了序列中每个元素对(例如 [h_1, h_1], [h_1, h_2], [h_1, h_3] 等)的组合。这为计算自注意力机制中的每个元素与序列中其他所有元素之间的关系提供了基础。每一行代表一个唯一的元素配对,这使得模型能够针对每个配对计算注意力分数。

  • 7
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值