我在训练模型时,用了一个标准CrossAttn处理(N,512,200,200)的矩阵,发现每次仅仅计算attention就会把12G的显存挤爆。后边我把head缩减为1,attn的embed_dim设为16,仍然一次会申请11G左右的资源,仍然会爆炸(这些还没算优化时的继续增长)。我模型本身的参数也就1M左右,为什么会爆呢?
这是因为矩阵太大了。Bert接受的最大长度也就512,我这个已经40000了,不爆炸才怪。ViT的改进也就是在于,把224*224的图片分成16*16的patch,每个patch当作一个word做pooling。
这其实涉及到中间activation的计算缓存问题,以下面为例,进行一些分析:
import torch
class MNET(torch.nn.Module):
def __init__(self):
super(MNET, self).__init__()
self.ln = torch.nn.Linear(3, 3,bias=True)
def forward(self, x):
print(id(x),torch.cuda.memory_allocated())
x = torch.add(x, 1)
print(id(x),torch.cuda.memory_allocated())
# x=self.ln(x)
x=x.transpose(1,2)
print(id(x),torch.cuda.memory_allocated())
return x
# grad == True 才会缓存
x = torch.rand(2,200, 3).cuda().requires_grad_(True)
model = MNET().cuda()
y = model(x)
'''
2748905790976 6144
2748906023568 11264
2748906132160 11264
'''
首先,缓存中间变量是为了求导。那么哪些操作需要缓存?
结论是:修改了内存数据本身,而且前序计算使用到了require_grad=True的leaf node
例如 y=WX,W需要grad,所以 dy/dW = X,需要保存X下来
例如transpose这种只改变了tensor元数据,不影响内存的操作,是不会缓存的。因为求导时的变换关系查表就能搞定,不需要任何新的数据。
x刚刚创建的时候,被分配内存了。
然后来到x=x+1。很多博客说过x=x+1和x+=1是不一样的,后者是inplace操作。实际上,就算前者不是inplace,如果不需要grad的话也不会copy副本。inplace可能带来梯度问题是因为id没变,不是因为没copy副本。我这里设了grad,所以从6144 变到了 11264!
x=x+1 | x+=1 | |
require grad | id不同,内存++ | id相同,内存不变 |
no grad | id不同,内存不变 | id相同,内存不变 |