训练 Transfomer 模型的内存消耗计算

经典图打底:

训练深度模型的内存消耗主要有以下几个部分:

  1. 存储模型可训练参数
  2. 存储梯度
  3. 存储反向传播中间变量,例如:

L = ( Y − Y ^ ) 2 Y ^ = X T W ∂ L ∂ W = − 2 ( Y − Y ^ ) ∂ Y ^ ∂ W = − 2 ( Y − X T W ) X \begin{aligned} L &= (Y - \hat Y)^2\\ \hat Y &= X^T W\\ \frac{\partial L}{\partial W}&= -2(Y-\hat Y) \frac{\partial \hat Y }{\partial W} = -2(Y- X^T W) X \end{aligned} LY^WL=(YY^)2=XTW=2(YY^)WY^=2(YXTW)X
这里面 X X X 就需要保存下来供反向传播时使用

下面具体的分析中需要用到每一层的具体运算张量,具体可以参考 Transfomer矩阵维度分析及MultiHead详解


model 内存

    """
    计算储存Transformer模型可训练参数所需的内存

    参数:
    - vocab_in_size: vocab_in大小
    - vocab_out_size: vocab_out大小
    - encoder_layers_num: 编码器层数
    - decoder_layers_num: 解码器层数
    - d_model: 编码器和解码器的隐藏层大小
    - num_head: 头的数量
    - embedding_size: 词嵌入大小
    - filter_size: 前馈子层的隐藏层大小
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - bias: 是否加偏置项
    - include_pos_embedding: 位置编码是否单独包含可优化参数
    - dropout_rate: 例如: 0.1
    - dtype_size: 默认为4 (FP32),若是FP16,改为2

    返回:
    - 所需内存,以字节为单位。
    """

    bias = bias * 1

    # 计算encoder embedding的参数内存消耗
    encoder_embedding_params = vocab_in_size * embedding_size

    # 计算 Encoder 的参数内存消耗
    # Multi-head Attention parameters: 3 * (d_model * d_model) + (d_model * d_model)
    # Layer normalization: d_model + d_model * bias
    # Feed-forward network parameters: d_model * filter_size + filter_size * d_model
    attention_params = 4 * d_model * d_model
    layer_norm_params = d_model + d_model * bias
    ffn_params_params = 2 * d_model * filter_size
    encoder_params = (attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * encoder_layers_num

    # 计算decoder embedding的参数内存消耗
    decoder_embedding_params = vocab_out_size * embedding_size
    # 计算 Decoder 的参数内存消耗
    # Masked Multi-head Attention parameters: 4 * (d_model * d_model)
    # Multi-head Attention parameters: 4 * (d_model * d_model)
    decoder_params = (attention_params + layer_norm_params + attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * decoder_layers_num

    # 计算最后 output 层的参数内存消耗
    output_params = d_model * vocab_out_size

    # 计算储存模型可训练参数所需内存,考虑 dropout_rate(近似估算)
    model_memory = (encoder_embedding_params + encoder_params + decoder_embedding_params + decoder_params + output_params) * (1 + dropout_rate) * dtype_size
    if include_pos_embedding:
        model_memory += seq_len * d_model * 2 # encoder 和 decoder 各有一个 pos embedding
     

gradients 内存

这里除了 gradients 内存,还考虑了一些小项,例如 mask,优化器 等消耗的内存

def get_inputs_mem(batch_size, seq_len, dtype_size=8):
    """
    计算Transformer模型输入数据的内存占用

    参数:
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - dtype_size: 默认为8 (int64)

    返回:
    - 所需内存,以字节为单位。
    """
    return batch_size * seq_len * dtype_size * 2  # 同时计算输入和输出

    # 计算attention中的mask的内存消耗
    # Mask: seq_len * seq_len for each attention block
    mask_memory = seq_len * seq_len * (encoder_layers_num + decoder_layers_num*2) * dtype_size

    # 计算gradients消耗的内存, 训练过程中的梯度与模型参数的形状相同,因此梯度的内存大小也是 model_memory
    grads_memory = model_memory

    # 计算优化器消耗的内存,此处以adam为例,对每一个可训练参数,需要储存一个一阶动量和一个二阶动量
    # 若使用的其他优化器,此处按需修改
    optimizer_memory = 2 * model_memory

    # 数据存储消耗的内存
    inputs_memory = get_inputs_mem(batch_size,seq_len)


activates 内存

    """
    计算中间结果(activates)的内存消耗,反向传播需要用到这些中间结果
    
     参数:
    - vocab_out_size: vocab_out大小
    - encoder_layers_num: 编码器层数
    - decoder_layers_num: 解码器层数
    - d_model: 编码器和解码器的隐藏层大小
    - num_head: 头的数量
    - filter_size: 前馈子层的隐藏层大小
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - dtype_size: 默认为4 (FP32),若是FP16,改为2

    返回:
    - 所需内存,以字节为单位。
    """

    # 由于各个layer的输入和输出size都是 batch_size * seq_len * d_model, 先计算出来后续使用
    N = batch_size * seq_len * d_model * dtype_size

    # 计算每层 attention 部分的中间结果内存消耗
    # 1.linear transformation: X*W_q = Q, X*W_k = K, X*W_v = V, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 X (只需储存一个,因为是同一个X)
    # 2.由于 Attention(Q,K,V) = softmax(QK^T/sqrt(d))V, 其中 QK^T 的张量为 [batch_size * num_head * seq_len * d_model/num_head] * [batch_size * num_head * d_model/num_head * seq_len] = [batch_size * num_head * seq_len * seq_len]
    # V 张量为 [batch_size * num_head * seq_len * d_model/num_head], 需要存储 Q, K, V, softmax(QK^T/sqrt(d))
    # 3.output linear transformation: Y = Attention(Q,K,V)*W_2, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 Attention(Q,K,V)
    linear_memory = N
    softmax_memory = 3 * N + batch_size * num_head * seq_len * seq_len * dtype_size
    output_memory = N
    attention_memory = linear_memory + softmax_memory + output_memory

    # 计算每层的 Layer normalization 的中间结果内存消耗, Layer normalization 输出张量为 batch_size * seq_len * d_model
    layer_norm_memory = N

    # 计算每层的 FFN 部分的中间结果内存消耗
    # 1.第一层 linear transformation: X*W_1 = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * filter_size] = [batch_size * seq_len * filter_size], 需储存 X
    # 2.中间 Relu 连接: Y' = Relu(Y), 需储存 Y'
    # 3.第二层 linear transformation: Y'*W_2 = Z, 张量为 [batch_size * seq_len * filter_size] * [batch_size * filter_size * d_model] = [batch_size * seq_len * d_model], 需储存 Y'
    ffn_memory = N + 2 * batch_size * seq_len * filter_size * dtype_size

    encoder_memory = (attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * encoder_layers_num
    decoder_memory = (attention_memory + layer_norm_memory + attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * decoder_layers_num

    # 计算 output 层的中间结果内存消耗
    # 1.output linear transformation: X*W = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * vocab_out_size] = [batch_size * seq_len * vocab_out_size], 需储存 X
    # 2.softmax(Y): 需储存 softmax(Y)
    output_memory = N + batch_size * seq_len * vocab_out_size * dtype_size

    total_activates_memory = encoder_memory + decoder_memory + output_memory

将上述三个部分加总,就是训练 Transfomer 模型大概需要的内存消耗。

NOTE:

  1. 这里没有考虑混合精度训练,如果考虑混合精度训练,还需要在不同的部分,使用不同的 dtype_size
  2. 如果是GPT这种 decoder-only 或者 encoder-only 的模型,只需要 decoder_layers_num = 0,即可 (decoder-only 也是这样做的,因为decoder-only 中的 Masked Multi-head Attention 没有了,实际的参数情况和 encoder-only 是一样的)

Reference:
Transformer Memory Arithmetic: Understanding all the Bytes in nanoGPT
Formula to compute approximate memory requirements of transformer models
Transformer Math 101

  • 15
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值