文章翻译自VADIM IRTLAC大佬的文章,用于解决模型训练时显存和速度的优化。原文请👇
Optimization approaches for Transformers | Kaggle
速览
原文是用于transformer 训练时的显存优化,但里面很多方法cnn等网络都适用,本文从中节选了部分通用(transformer, cnn)的一些显存优化方法,下面表格对各类方法进行简单的总结:
方法简介
1.梯度累加
- 小batch_size训练: 显存占用少,速度慢,收敛慢,效果差,梯度下降算法在batch小时更敏感
- 大batch_size训练: 显存占用大,速度快,收敛快,效果好,
因此为了模型更好地收敛和提高训练速度,更希望使用大batch_size进行训练,大batch_size需要更大的显存。因此想到了用多次小batch_size反向传播累加梯度模拟一次大batch_size。思路如下图,四个小batch_size模拟一次大batch_size。
图1
代码:
steps = len(loader)
# perform validation loop each `validation_steps` training steps!
validation_steps = int(validation_steps * gradient_accumulation_steps)
for step, batch in enumerate(loader, 1):
# prepare inputs and targets for the model and loss function respectively.
# forward pass
outputs = model(inputs)
# computing loss
loss = loss_fn(outputs, targets)
# accumulating gradients over steps
if gradient_accumulation_steps > 1: #对照图1,gradient_accumulation_steps=4
loss = loss / gradient_accumulation_steps #注意损失需要求平均
# backward pass
loss.backward()
# perform optimization step after certain number of accumulating steps and at the end of epoch
if step % gradient_accumulation_steps == 0 or step == steps: #当到达step的时候才进行网络的梯度更新
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
model.zero_grad()
# perform validation loop
if step % validation_steps == 0:
validation_loop()
2.冻结
冻结freeze是一种加速训练和减少内存的有效方法。众所周知,深度学习在网络低层学习通用的数据模式,同时最高层次学习高级特性用于目标的任务。因此有一个在大数据集上训练的较好模型,浅层参数是可复用的;此外,当执行优化算法(例如SGD, AdamW或RMSprop),较低的层接收小的梯度较小,叫做梯度消失,因此对于这些网络层我们可以冻结。
PyTorch提供了一个API来控制梯度是否计算,该参数是:requires_grad 。该参数为False表示不计算梯度,True表示计算梯度。
代码
def freeze(module):
"""
Freezes module's parameters.
"""
for parameter in module.parameters():
parameter.requires_grad = False
def get_freezed_parameters(module):
"""
Returns names of freezed parameters of the given module.
"""
freezed_parameters = []
for name, parameter in module.named_parameters():
if not parameter.requires_grad:
freezed_parameters.append(name)
return freezed_parameters
import torch
from transformers import AutoConfig, AutoModel
# initializing model
model_path = "microsoft/deberta-v3-base"
config = AutoConfig.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, config=config)
# freezing embeddings and first 2 layers of encoder
freeze(model.embeddings) #冻结embedding层
freeze(model.encoder.layer[:2]) #这里冻结encoder网络的前两层
freezed_parameters = get_freezed_parameters(model)
print(f"Freezed parameters: {freezed_parameters}")
# selecting parameters, which requires gradients and initializing optimizer
model_parameters = filter(lambda parameter: parameter.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(params=model_parameters, lr=2e-5, weight_decay=0.0)
3.自动混合精度AMP
自动混合精度(AMP)是另一种在不损失最终质量的情况下减少显存消耗和训练时间的方法,该方法由NVIDIA和百度研究人员在2017年的"Mixed Precision Training"论文中提出。思想是使用较低的精度将模型的梯度和参数保留在内存中,即不使用全精度(float32),而是使用半精度(例如float16)将张量保存在内存中。然而,当以较低精度计算梯度时,某些值可能太小,以至于被视为零,这种现象被称为“溢出”。为了防止“溢出”,原始论文的作者提出了一种梯度缩放方法。
PyTorch从1.6的版本开始提供了一个包:torch.cuda.amp
,具有使用自动混合精度所需的功能(从降低精度到梯度缩放),自动混合精度作为上下文管理器实现,因此可以随时随地的插入到训练和推理脚本中。
from torch.cuda.amp import autocast, GradScaler #amp接口
scaler = GradScaler()
for step, batch in enumerate(loader, 1):
# prepare inputs and targets for the model and loss function respectively.
# forward pass with `autocast` context manager
with autocast(enabled=True):
outputs = model(inputs)
# computing loss
loss = loss_fn(outputs, targets)
# scale gradint and perform backward pass
scaler.scale(loss).backward() #对损失缩放后再反向传播
# before gradient clipping the optimizer parameters must be unscaled.
scaler.unscale_(optimizer) #梯度裁剪前参数必须复原
# perform optimization step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
4. 8-bit优化器
8-bit优化的思想类似于自动混合精度。除了模型的参数和梯度保持较低的精度, 该方法还让优化器保持在低精度状态。原始论文 "8-bit Optimizers via Block-wise Quantization"有详细的描述,8-bit优化能显著降低内存利用率和稍微加快训练。此外, 作者研究了不同hyper parameter依然效果稳定。作者还提供了一个开源库 ,称为 bitsandbytes.
代码
!pip install -q bitsandbytes-cuda110
def set_embedding_parameters_bits(embeddings_path, optim_bits=32):
"""
https://github.com/huggingface/transformers/issues/14819#issuecomment-1003427930
"""
embedding_types = ("word", "position", "token_type")
for embedding_type in embedding_types:
attr_name = f"{embedding_type}_embeddings"
if hasattr(embeddings_path, attr_name):
bnb.optim.GlobalOptimManager.get_instance().register_module_override(
getattr(embeddings_path, attr_name), 'weight', {'optim_bits': optim_bits}
)
import bitsandbytes as bnb
# selecting parameters, which requires gradients
model_parameters = filter(lambda parameter: parameter.requires_grad, model.parameters())
# initializing optimizer
bnb_optimizer = bnb.optim.AdamW(params=model_parameters, lr=2e-5, weight_decay=0.0, optim_bits=8)
# bnb_optimizer = bnb.optim.AdamW8bit(params=model_parameters, lr=2e-5, weight_decay=0.0) # equivalent to the above line
# setting embeddings parameters
set_embedding_parameters_bits(embeddings_path=model.embeddings)
print(f"8-bit Optimizer:\n\n{bnb_optimizer}")
5. 梯度检查点
有时候,即使用了上面的几种方法,显存可能还是不够,尤其是在模型足够大的情况下。那么梯度检查点(Gradient Checkpointing)就是大招了,这个方法第一次在 "Training Deep Nets With Sublinear Memory Cost" ,作者表明梯度检查点可以显著降低显存利用率,从 O(n) 降低到 O(n) ,其中n是模型的层数。这种方法允许在单个GPU上训练大型模型,或者提供更多内存以增加批量大小,从而更好更快地收敛。梯度检查点背后的思想是在小数据块中计算梯度,同时在正向和反向传播过程中从内存中移除不必要的梯度,从而降低内存利用率,但是这种方法需要更多的计算步骤来再现整个反向传播图,其实就是一种用时间来换空间的方法。
PyTorch框架里也有梯度检查点的实现,函数:torch.utils.checkpoint.checkpoint
和torch.utils.checkpoint.checkpoint_sequential
这边引用一段torch官网对梯度检查点的介绍:
在前向传播中,函数用torch.no_grad() 方式进行,此时,不存储中间激活,而是存储输入元祖和函数参数;在反向传播中,先通过前向传播的方式再次计算函数,跟踪中间激活值,然后使用这些激活值计算梯度计。
此外,HuggingFace Transformers也支持梯度检查点。梯度检查点可以通过PreTrainedModel实例的gradient_checkpointing_enable方法执行,一行代码搞定!
from transformers import AutoConfig, AutoModel
# https://github.com/huggingface/transformers/issues/9919
from torch.utils.checkpoint import checkpoint
# initializing model
model_path = "microsoft/deberta-v3-base"
config = AutoConfig.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, config=config)
# gradient checkpointing
model.gradient_checkpointing_enable()
print(f"Gradient Checkpointing: {model.is_gradient_checkpointing}")
参考文献:
- Optimization approaches for Transformers | Kaggle
- Performance and Scalability: How To Fit a Bigger Model and Train It Faster
- Speeding up Transformer w/ Optimization Strategies
- Things you can try to speed up training speed and preventing memory shortage if you are using transformers.
- 8-bit Adam and other memory optimizations
- Fitting larger networks into memory.