依赖：
pip install lightning
插拔式改动（原有 PyTorch 训练代码只需改动以下几处）：
from lightning.fabric import Fabric
#...
# Instantiate Fabric; it decides where the model and data live (here: CUDA).
fabric = Fabric(accelerator='cuda')
# For mixed-precision training, use the line below instead (the author reports
# a clear speedup). NOTE(review): bf16 needs hardware bfloat16 support —
# confirm on the target GPU.
#fabric = Fabric(accelerator="cuda", precision="bf16-mixed")
fabric.launch()
#...
# Drop-in integration: let Fabric wrap the model/optimizer and the dataloader.
# These calls come after fabric.launch(), as shown here.
model, optimizer = fabric.setup(model, optimizer)
# NOTE(review): only the training dataloader is wrapped; train() also takes a
# val_loader — if it is iterated, it presumably needs
# fabric.setup_dataloaders(...) as well. Confirm.
train_dataloader = fabric.setup_dataloaders(train_dataloader)
#...
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    """Run the training loop with Lightning Fabric handling the backward pass.

    Only the lines that differ from a plain-PyTorch loop are shown; the rest
    of the per-batch work is elided (see the trailing ``# ...``).

    Args:
        num_epochs: number of passes over ``train_loader``.
        model: the model (presumably the one wrapped by ``fabric.setup``).
        optimizer: the optimizer (presumably the one wrapped by ``fabric.setup``).
        train_loader: iterable yielding ``(features, targets)`` pairs.
        val_loader: validation loader; not used in the visible part of this
            snippet.
        fabric: the launched ``Fabric`` instance; its ``device`` and
            ``backward`` are used below.
    """
    # NOTE(review): relies on module-level `import torch.nn.functional as F`
    # and `import torchmetrics`, which this snippet does not show.
    for epoch in range(num_epochs):
        # Fresh 10-class accuracy metric each epoch, moved to Fabric's device.
        # NOTE(review): never updated in the visible code; presumably fed in
        # the elided part of the batch loop.
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            # Re-enter train mode every batch; presumably the elided code
            # switches the model to eval mode later in the iteration.
            model.train()
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            # Drop-in change: the original plain-PyTorch call was loss.backward().
            fabric.backward(loss)
            optimizer.step()
            # ...
参考文献：
CVPR 2023 Talk: Scaling PyTorch Model Training With Minimal Code Changes