Pytorch 显存机制 (以及为什么 x=x+1没有显存增长)

我在训练模型时,用了一个标准CrossAttn处理(N,512,200,200)的矩阵,发现每次仅仅计算attention就会把12G的显存挤爆。后边我把head缩减为1,attn的embed_dim设为16,仍然一次会申请11G左右的资源,仍然会爆炸(这些还没算优化时的继续增长)。我模型本身的参数也就1M左右,为什么会爆呢?

这是因为矩阵太大了。Bert接受的最大长度也就512,我这个已经40000了,不爆炸才怪。ViT的改进也就是在于,把224*224的图片分成16*16的patch,每个patch当作一个word做pooling。

这其实涉及到中间activation的计算缓存问题,以下面为例,进行一些分析:

import torch

class MNET(torch.nn.Module):
    def __init__(self):
        super(MNET, self).__init__()
        self.ln = torch.nn.Linear(3, 3,bias=True)
    def forward(self, x):
        print(id(x),torch.cuda.memory_allocated())
        x = torch.add(x, 1)
        print(id(x),torch.cuda.memory_allocated())
        # x=self.ln(x)
        x=x.transpose(1,2)
        print(id(x),torch.cuda.memory_allocated())

        return x

# grad == True 才会缓存
x = torch.rand(2,200, 3).cuda().requires_grad_(True) 
model = MNET().cuda()
y = model(x)

'''

2748905790976 6144
2748906023568 11264
2748906132160 11264

'''

首先,缓存中间变量是为了求导。那么哪些操作需要缓存?

        结论是修改了内存数据本身,而且前序计算使用到了require_grad=True的leaf node

        例如  y=WX,W需要grad,所以 dy/dW = X,需要保存X下来 

        例如transpose这种只改变了tensor元数据,不影响内存的操作,是不会缓存的。因为求导时的变换关系查表就能搞定,不需要任何新的数据。

        x刚刚创建的时候,被分配内存了。

        然后来到x=x+1。很多博客说过x=x+1和x+=1是不一样的,后者是inplace操作。实际上,就算前者不是inplace,如果不需要grad的话也不会copy副本。inplace可能带来梯度问题是因为id没变,不是因为没copy副本。我这里设了grad,所以从6144 变到了 11264!

x=x+1x+=1
require gradid不同,内存++id相同,内存不变
no gradid不同,内存不变id相同,内存不变

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值