GAT(参数中加入batch)

实现带有batch的GAT

作者实现的是不带batch的,踩了一点坑,自己改了一下带batch的

from torch import nn
import torch
import torch.nn.functional as F
class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        x = torch.einsum('ncvl,vw->ncwl', (x, A))
        return x.contiguous()

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 
    图注意力层
    input: (B,N,C_in)
    output: (B,N,C_out)
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features   # 节点表示向量的输入特征数
        self.out_features = out_features   # 节点表示向量的输出特征数
        self.dropout = dropout    # dropout参数
        self.alpha = alpha     # leakyrelu激活的参数
        self.concat = concat   # 如果为true, 再进行elu激活
        
        # 定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  
        nn.init.xavier_uniform_(self.W.data, gain=1.414)  # 初始化
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)   # 初始化
        
        # 定义leakyrelu激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    
    def forward(self, inp, adj):
        """
        inp: input_fea [B,N, in_features]  in_features表示节点的输入特征向量元素个数
        adj: 图的邻接矩阵  [N, N] 非零即一,数据结构基本知识
        """
        h = torch.matmul(inp, self.W)   # [B, N, out_features]
        N = h.size()[1]    # N 图的节点数

        a_input = torch.cat([h.repeat(1,1,N).view(-1, N*N, self.out_features), h.repeat(1, N, 1)], dim=-1).view(-1, N, N, 2*self.out_features)
        # [B, N, N, 2*out_features]
        
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        # [B, N, N, 1] => [B, N, N] 图注意力的相关系数(未归一化)
        
        zero_vec = -1e12 * torch.ones_like(e)    # 将没有连接的边置为负无穷


        attention = torch.where(adj>0, e, zero_vec)   # [B, N, N]
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留,
        # 否则需要mask并置为非常小的值,原因是softmax的时候这个最小值会不考虑。
        attention = F.softmax(attention, dim=1)    # softmax形状保持不变 [B, N, N],得到归一化的注意力权重!
        attention = F.dropout(attention, self.dropout, training=self.training)   # dropout,防止过拟合
        h_prime = torch.matmul(attention, h)  # [B, N, N].[B, N, out_features] => [B, N, out_features]
        # 得到由周围节点通过注意力权重进行更新的表示
        if self.concat:
            return F.relu(h_prime)
        else:
            return h_prime 
    
    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    def __init__(self, n_feat, n_hid, n_class, dropout, alpha, n_heads):
        """Dense version of GAT
        n_heads 表示有几个GAL层,最后进行拼接在一起,类似self-attention
        从不同的子空间进行抽取特征。
        """
        super(GAT, self).__init__()
        self.dropout = dropout 
        
        # 定义multi-head的图注意力层
        self.attentions = [GraphAttentionLayer(n_feat, n_hid, dropout=dropout, alpha=alpha, concat=True) for _ in range(n_heads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)   # 加入pytorch的Module模块
        # 输出层,也通过图注意力层来实现,可实现分类、预测等功能
        self.out_att = GraphAttentionLayer(n_hid * n_heads, n_class, dropout=dropout,alpha=alpha, concat=False)
    
    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)   # dropout,防止过拟合
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)  # 将每个head得到的表示进行拼接
        x = F.dropout(x, self.dropout, training=self.training)   # dropout,防止过拟合
        print(x.shape)
        x = F.elu(self.out_att(x, adj))   # 输出并激活
        return F.log_softmax(x, dim=2)  # log_softmax速度变快,保持数值稳定



x=torch.randn(32,32,5,12)
adj=torch.tensor([
    [0,1,0,1,1],
    [1,0,1,0,0],
    [0,1,0,1,0],
    [1,0,1,0,1],
    [1,0,0,1,0]
],dtype=torch.float32)
w=torch.randn(32*12,32*12)
x=x.permute(0,2,1,3)
x=torch.reshape(x,(32,5,-1))

net=GAT(32*12,32*12,64,0.1,1,2)
out=net(x,adj)

print(out.shape)
  • 15
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 32
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 32
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值