基于稀疏矩阵计算实现Graph Attention Network

非稀疏矩阵的GAT由于如下代码的存在,因此空间复杂度非常高。但是对于大多数网络来说,边都是稀疏的,而GAT计算的是任意两个节点对之间的权重,因此许多计算都是没有必要的。

a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)

https://blog.csdn.net/qq_36618444/article/details/108099479 可以看稀疏矩阵计算所需要的环境。但是可能因为pytorch_sparse还在迭代更新中,每个版本的功能都不甚相同,因此在github上找了好几个版本都无法顺利运行,所以自己试着根据别人的代码改写了一个:

import math
import numpy as np
import time
import os
import config

import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch_sparse import spmm   # product between dense matrix and sparse matrix
import torch_sparse as torchsp
from torch_scatter import scatter_add, scatter_max
import torch.sparse as sparse

class SparseGATLayer(nn.Module):
    """
    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, input_dim, out_dim, dropout, alpha, concat=True):
        super(SparseGATLayer, self).__init__()
        self.in_features = input_dim
        self.out_features = out_dim
        self.alpha = alpha
        self.concat = concat
        self.dropout = dropout
        self.W = nn.Parameter(torch.zeros(size=(input_dim, out_dim)))  # FxF'
        self.attn = nn.Parameter(torch.zeros(size=(1, 2 * out_dim)))  # 2F'
        nn.init.xavier_normal_(self.W, gain=1.414)
        nn.init.xavier_normal_(self.attn, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, x, adj):
        '''
        :param x:   dense tensor. size: nodes*feature_dim
        :param adj:    parse tensor. size: nodes*nodes
        :return:  hidden features
        '''
        N = x.size()[0]   # 图中节点数
        edge = adj._indices()   # 稀疏矩阵的数据结构是indices,values,分别存放非0部分的索引和值,edge则是索引。edge是一个[2*NoneZero]的张量,NoneZero表示非零元素的个数
        if x.is_sparse:   # 判断特征是否为稀疏矩阵
            h = torch.sparse.mm(x, self.W)
        else:
            h = torch.mm(x, self.W)
        # Self-attention (because including self edges) on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()  # edge_h: 2*D x E
        values = self.attn.mm(edge_h).squeeze()   # 使用注意力参数对特征进行投射
        edge_e_a = self.leakyrelu(values)  # edge_e_a: E   attetion score for each edge,对应原论文中的添加leakyrelu操作
        # 由于torch_sparse 不存在softmax算子,所以得手动编写,首先是exp(each-max),得到分子
        edge_e = torch.exp(edge_e_a - torch.max(edge_e_a))
        # 使用稀疏矩阵和列单位向量的乘法来模拟row sum,就是N*N矩阵乘N*1的单位矩阵的到了N*1的矩阵,相当于对每一行的值求和
        e_rowsum = spmm(edge, edge_e, m=N, n=N, matrix=torch.ones(size=(N, 1)).cuda())  # e_rowsum: N x 1,spmm是稀疏矩阵和非稀疏矩阵的乘法操作
        h_prime = spmm(edge, edge_e, n=N,m=N, matrix=h)   # 把注意力评分与每个节点对应的特征相乘
        h_prime = h_prime.div(e_rowsum + torch.Tensor([9e-15]).cuda())  # h_prime: N x out,div一看就是除,并且每一行的和要加一个9e-15防止除数为0
        # softmax结束
        if self.concat:
            # if this layer is not last layer
            return F.elu(h_prime)
        else:
            # if this layer is last layer
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

这里需要强调一下,spmm函数的源代码中的注释好像有错误。mn分别表示稠密矩阵的两个维度,但是代码中我设置了m=N,n=N,而稠密矩阵的size为N*1,这显然不符。一个可能的解释是m为稀疏矩阵的第一个维度,n为稠密矩阵的第一个维度,因为矩阵相乘中,稀疏矩阵的第二个维度和稠密矩阵的第一个维度必须相等,因此无关稠密矩阵的第二个维度。

"""Matrix product of sparse matrix with dense matrix.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        matrix (:class:`Tensor`): The dense matrix.

    :rtype: :class:`Tensor`
    """
  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 11
    评论
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

五月的echo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值