GAT源码剖析

题解:

在这里插入图片描述
首先挂出核心公式和训练过程生成的aij注意系数

layer.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout #dropout 参数
        self.in_features = in_features # 输入的特征
        self.out_features = out_features # 输出特征
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        '''
            先torch.empty创建的是一个size 大小的 torch.Tensor类型 这个类型是不可训练的
            然后使用Parameter命令对 原来的Parameter类型进行绑定并且转化为Parameter 类型
            Parameter 是一个可训练的类型
        '''

        nn.init.xavier_uniform_(self.W.data, gain=1.414) # gain 是两种方法计算的中 a 和 std 计算的重要参数
        '''
            Xavier 是一种初始化的方式 pytroch 提供了uniform 和  normal 两种方式:
            nn.init.xavier_uniform_(tensor, gain =1) 是均匀分布 (-a,a)
            nn.init.xavier_normal_(tensor, gain=1) 正态分布~N ( 0 ,std )
            https://blog.csdn.net/dss_dssssd/article/details/83959474 讲解博客地址
        '''

        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        '''
                    论文中有一个公式whi||whj 
                    || 是连接符号 通过这个连接,我们把两个1 X F 的矩阵变成了 一个 1 X 2F 的矩阵
                    然后论文中乘以了一个a 2F X 1 的矩阵 那么就得到了一个数
                    这个a就是论文中的那个a
                    而我们的得到的那个数就是我们的attention系数 
                '''

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        '''
            因为原式是 aij= softmax(sigmod(a (whi||whj)))
            sigmod 是激活函数 这里用的是leaky ReLU 函数
        '''
    '''
        forward 和 _prepare_attentional_mechanism_input 对应的论文 2.1环节
    '''
    def forward(self, h, adj): # 正向传播
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        # 这里是做一个乘法
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0] # number of nodes

        '''
            Wh size = [2708,8] 8 是标签数目
            
            size()[0] size顾名思义是大小的意思 [] 的应用是在维度上
            这里[0]代表的是在0维度 
            例如我们创建一个torch.rand([2, 1, 3, 3])
            那么 size()[1] 就等于 1
            
        '''
        # 下面,创建了两个矩阵,它们在行中的嵌入顺序不同
        # (e stands for embedding) e 是 embedding 的基础
        # 这些是第一个矩阵的行 (Wh_repeated_in_chunks):
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        # 
        # 这些是第二个矩阵的行 (Wh_repeated_alternating):
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN 
        # '----------------------------------------------------' -> N times
        # 
        
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        '''
            repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None)  是复制函数
            参数说明: 
            self: 传入 的数据为 tensor 
            repeats : 复制到几份
            dim : 要复制的维度 可以设定为 012
            Examples:
            此处定义了一个4维tensor,要对第2个维度复制,由原来的1变为3,即将设定dim=1。
            data1 = torch.rand([2, 1, 3, 3])
            data2 = torch.repeat_interleave(data1, repeats=3, dim=1)
        '''
        Wh_repeated_alternating = Wh.repeat(N, 1)
        '''
            repeat 函数:
            第一个参数是复制的份数
            第二个参数是复制的维度是那个维度
            Examples1:
                data1 = np.array([[1,2,3],
                     [4,5,6]])
                data1.repeat(2,0)
                array([[1, 2, 3],
               [1, 2, 3],
               [4, 5, 6],
               [4, 5, 6]])
            Examples2:
                data1 = np.array([[1,2,3],
                 [4,5,6]])
                data1.repeat(2,1)
                array([[1, 1, 2, 2, 3, 3],
               
  • 5
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值