点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
编者荐语
Transformer 有着巨大的内存和算力需求,因为它构造了一个注意力矩阵,需求与输入呈平方关系。文章为你介绍神经网络的内存计算方法。
在微调GPT/BERT模型时,会经常遇到“ cuda out of memory”的情况。这是因为transformer是内存密集型的模型,并且内存要求也随序列长度而增加。所以如果能对模型的内存要求进行粗略的估计将有助于估计任务所需的资源。
如果你想直接看结果,可以跳到本文最后。不过在阅读本文前请记住所有神经网络都是通过反向传播的方法进行训练的, 这一点对于我们计算内存的占用十分重要。
total_memory = memory_modal + memory_activations + memory_gradients
这里的memory_modal是指存储模型所有参数所需的内存。memory_activations是计算并存储在正向传播中的中间变量,在计算梯度时需要使用这些变量。因为模型中梯度的数量通常等于中间变量的数量,所以memory_activations= memory_gradients。因此可以写成:
total_memory = memory_modal + 2 * memory_activations
所以我们计算总体内存的要求时只需要找到memory_modal和memory_activations就可以了。
估算模型的内存
下面我们以GPT为例。GPT由许多transformer块组成(后面我用n_tr_blocks表示其数量)。每个transformer块都包含以下结构:
multi_headed_attention --> layer_normalization --> MLP -->layer_normalization
每个multi_headed_attention元素都由键,值和查询组成。其中包括n_head个注意力头和dim个维度。MLP是包含有n_head * dim的尺寸。这些权重都是要占用内存的,那么:
memory_modal = memory of multi_headed_attention + memory of MLP
= memory of value + memory of key + memory of query + memory of MLP
= square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
= 4*square_of(n_head * dim)
因为我们的模型包含了n个单元。所以最后内存就变为:
memory_modal = 4*n_tr_blocks*square_of(n_head * dim)
上面的估算没有考虑到偏差所需的内存,因为这大部分是静态的,不依赖于批大小、输入序列等。
估算中间变量的内存
多头注意力通常使用softmax,可以写成:
multi_headed_attention = softmax(query * key * sequence_length) * value
k,q,v的维度是:
[batch_size, n_head, sequence_length, dim]
multi_headed_attention操作会得出如下形状:
[batch_size, n_head, sequence_length, sequence_length]
所以最终得内存为:
memory_softmax = batch_size * n_head * square_of(sequence_length)
q* k * sequence_length操作乘以value的形状为[batch_size, n_head, sequence_length, dim]。MLP也有相同的维度:
memory of MLP = batch_size * n_head * sequence_length * dim
memory of value = batch_size * n_head * sequence_length * dim
我们把上面的整合在一起,单个transformer的中间变量为:
memory_activations = memory_softmax + memory_value + memory_MLP
= batch_size * n_head * square_of(sequence_length)
+ batch_size * n_head * sequence_length * dim
+ batch_size * n_head * sequence_length * dim
= batch_size * n_head * sequence_length * (sequence_length + 2*dim)
再乘以块的数量,模型所有的memory_activations就是:
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
整合在一起
我们把上面两个公式进行归纳总结,想看结果的话直接看这里就行了。transformer模型所需的总内存为:
total_memory = memory_modal + 2 * memory_activations
模型参数的内存:
4*n_tr_blocks*square_of(n_head * dim)
中间变量内存:
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
我们使用下面的符号可以更简洁地写出这些公式。
R = n_tr_blocks = transformer层堆叠的数量
N = n_head = 注意力头数量
D = dim = 注意力头的维度
B = batch_size = 批大小
S = sequence_length =输入序列的长度
memory modal = 4 * R * N^2 * D^2
memory activations = RBNS(S + 2D)
所以在训练模型时总的内存占用为:
M = (4 * R * N^2 * D^2) + RBNS(S + 2D)
因为内存的占用和序列长度又很大的关系,如果有一个很长的序列长度S >> D S + 2D <——> S,这时可以将计算变为:
M = (4 * R * N^2 * D^2) + RBNS(S) = 4*R*N^2*D^2 + RBNS^2
可以看到对于较大的序列,M与输入序列长度的平方成正比,与批大小成线性比例,这也就证明了序列长度和内存占用有很大的关系。
所以最终的内存占用的评估为:
总内存 = ((4 * R * N^2 * D^2) + RBNS(S + 2D)) * float64(以字节为单位)
好消息!
小白学视觉知识星球
开始面向外开放啦👇👇👇
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~