官方的将pytorch转换为fabric简单分为五个步骤:
步骤 1:
在训练代码的开头创建 Fabric
对象
from lightning.fabric import Fabric
fabric = Fabric()
步骤 2:
如果打算使用多个设备(例如多 GPU),就调用 launch()
fabric.launch()
步骤 3:
在每个模型和优化器对上调用 setup()
,在所有数据加载器上调用 setup_dataloaders()
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)
步骤 4:
删除所有 .to
和 .cuda
调用,因为 Fabric
将自动处理
- model.to(device) # 删除
- batch.to(device) # 删除