- MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
- 分别基于 Pytorch 内置的优化器和 MMEngine 的
优化器封装(OptimWrapper)
进行单精度训练
、混合精度训练
和梯度累加
,对比二者实现上的区别。
一、 基于 Pytorch 的 SGD 优化器实现单精度训练
import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F
inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()
for input, target in zip(inputs, targets):
output = model(input)
loss = F.l1_loss(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
二、 使用 MMEngine 的优化器封装实现单精度训练
from mmengine.optim import OptimWrapper
optim_wrapper = OptimWrapper(optimizer=optimizer)
for input, target in zip(inputs, targets):
output = model(input)
loss = F.l1_loss(output, target)
optim_wrapper.update_params(loss)
优化器封装的 update_params
实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
三、 基于 Pytorch 的 SGD 优化器实现混合精度训练
- 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
- 内存占用更少
- 计算更快
from torch.cuda.amp import autocast
model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10
for input, target in zip(inputs, targets):
with autocast():
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
loss.backward()
optimizer.step()
optimizer.zero_grad()
四、 基于 MMEngine 的 优化器封装实现混合精度训练
from mmengine.optim import AmpOptimWrapper
optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
for input, target in zip(inputs, targets):
with optim_wrapper.optim_context(model):
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
optim_wrapper.update_params(loss)
- 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍
五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加
for idx, (input, target) in enumerate(zip(inputs, targets)):
with autocast():
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
loss.backward()
if idx % 2 == 0:
optimizer.step()
optimizer.zero_grad()
六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加
optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)
for input, target in zip(inputs, targets):
with optim_wrapper.optim_context(model):
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
optim_wrapper.update_params(loss)
只需要配置 accumulative_counts
参数,并调用 update_params
接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper
上下文,可以避免梯度累加阶段不必要的梯度同步。
七、 获取学习率/动量
优化器封装提供了 get_lr
和 get_momentum
接口用于获取优化器的一个参数组的学习率:
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper
model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)
print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr']) # 0.01
print(optimizer.param_groups[0]['momentum']) # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr()) # {'lr': [0.01]}
print(optim_wrapper.get_momentum()) # {'momentum': [0]}
八、 导出/加载状态字典
优化器封装和优化器一样,提供了 state_dict
和 load_state_dict
接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper
model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)
# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)
# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)
九、 使用多个优化器
OptimWrapperDict
的核心功能是支持批量
导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。
from torch.optim import SGD
import torch.nn as nn
from mmengine.optim import OptimWrapper, OptimWrapperDict
# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)
# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)
# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)
# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)
print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr()) # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum()) # {'gen.momentum': [0], 'disc.momentum': [0]}