GAT网络为什么占用那么多的显存

GAT网络显存占用原因分析(非专业,纯个人理解)

本blog主要聚焦以下几点:

  • GAT为什么占用那么多的显存,尤其是在节点多的时候
  • Transformer为什么就可以处理很长的输入,GAT就不行?

⚠⚠:并非图相关的方向,仅是打工时用到了,简单了解了一下。

1:GAT的注意力机制

GAT占用显存多的原因一句话概括就是:注意力机制

首先来看GAT的注意力机制是什么样的:
在这里插入图片描述

图片先随便放一张,之后再改:来源(https://zhuanlan.zhihu.com/p/137592078)

也就是,每一个节点可以用一个向量 h i h_i hi表示,那么在计算attention score的时候,就是 W h i Wh_i Whi向量与 W h j Wh_j Whj向量拼接之后( W W W是去可学习的权重),再乘一个可学习权重 a a a。,之后再求softmax,并且使用了mask,让每个节点只使用和他直接相连的节点进行计算。

其实很容易理解,就是先把每一个节点的向量过一个mlp,之后使用每个节点得到的新向量,计算每两个节点之间的加权求和。这就是主要的费显存的地方。为什么呢?请看这部分的代码是如何实现的:

GAT的代码我直接在网上找的一个,改成了batch版本,代码来源:https://zhuanlan.zhihu.com/p/128072201

    def forward(self, inp, adj):
        """
        inp: input_fea [B, N, in_features]  in_features表示节点的输入特征向量元素个数
        adj: 图的邻接矩阵 维度[B, N, N] 非零即一,数据结构基本知识
        """

        h = torch.matmul(inp, self.W)  # [B, N, out_features]
        # print(h.size())
        print('1: ', torch.cuda.memory_allocated())
        N = h.size()[1]  # N 图的节点数
        # 最占显存的就是这下面一行
        # 单独一个a_input就会占用276.39兆的显存,而实际上这一个计算之后立马就会新增0.81G的显存占用
        a_input = torch.cat([h.repeat(1, 1, N).view(-1, N * N, self.out_features),
                             h.repeat(1, N, 1)], dim=1).view(-1, N, N, 2 * self.out_features)
        print('2: ', torch.cuda.memory_allocated())

其中,在计算 a i n p u t a_input ainput的时候计算的就是 ( W h i , W h j ) (Wh_i, Wh_j) (Whi,Whj),为了便于快速计算出 N N N个节点中两两之间的attention score,要么就两层循环嵌套,一共循环 N 2 N^2 N2次,要么直接将矩阵 W h i Wh_i Whi W h j Wh_j Whj在不同的维度上重复 N N N次,然后拼接。repeat的效果大概是下面这样:
在这里插入图片描述
同一个颜色代表同一个向量,左边是把每一个重复 N N N次,右边是把整体重复 N N N次。拼接起来之后,刚刚好每一行就是一个向量 h i h_i hi和另一个向量 h j h_j hj,之后在乘一个矩阵 a a a,直接就计算出了两个向量之间的加权求和值。

而恰恰就是这个向量导致整体的显存占用暴涨。我的场景是有一个500+节点的图,每一个图的特征是6,我设置的 o u t _ f e a t u r e s out\_features out_features数目是128。所以我上面repeat之前的矩阵大小是 [ B , 500 , 128 ] [B, 500, 128] [B,500,128],重复之后变成了 [ B , 500 , 500 , 256 ] [B, 500, 500, 256] [B,500,500,256],此时我把B设为4。那么一共有 4 ∗ 500 ∗ 500 ∗ 256 = 256000000 4 * 500 * 500 * 256 = 256000000 4500500256=256000000个数字,而一个单精度的foat32的数占用四个字节,因此该tensor占用 256000000 / 256 / 1024 / 1024 = 0.95 G B 256000000 / 256 / 1024 / 1024 = 0.95GB 256000000/256/1024/1024=0.95GB,而我们还会设置多头注意力,不过没经过一个头,算完之后,这个显存会被回收,所以八个头其实占用和一个头差不太多。

到这里虽然占用比较多,但是还在可接受的范围之内,因为我们的显卡经常都是32G的显或者40G的显存。

但是,当把八个头的输出叠加到一起之后,是一个 [ B , N , o u t _ f e a t u r e s ∗ 8 ] [B, N, out\_features*8] [B,N,out_features8]大小的矩阵,此时再加一个输出的GAT头,因为我还有1600个类别,所以输出头的 o u t _ f e a t u r e s out\_features out_features就是1600。此时再继续按照上面的方法来算的话,就会发现,完蛋了。

我们的显存占用将会是 [ B , N , N , 3200 ] [B, N, N, 3200] [B,N,N,3200],占用显存直接变为 4 ∗ 500 ∗ 500 ∗ 3200 / 256 / 1024 / 1024 = 11.9 G B 4 * 500 * 500 * 3200 / 256 / 1024 / 1024 = 11.9GB 45005003200/256/1024/1024=11.9GB,直接就变得非常大了,而这只是理想情况下,实际上,pytorch会分配更多的显存。 因此,bs必须非常小才能勉强维持不爆显存。

2:为什么transformer就可以有很长的输入

这个也很简单,因为transformer使用的注意力机制是Scaled Dot-Product Attention

它不需要对tensor进行重复操作,只需要直接dot-product就行了。

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值