目录
二、移动优化器 (optimizer(s)) 和调度器 (schedulers)
六、移除任何 .cuda() 或 to.device() 调用
项目地址:https://github.com/PyTorchLightning/pytorch-lightning
为使 PyTorch 代码能够用 Lightning 运行,本文将简要介绍如何 将 PyTorch 组织到 Lightning 中。
一、移动运算代码 (computational code)
将模型架构和前向传播移动到 LightningModule:
class LitModel(LightningModule):
def __init__(self):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
return x
二、移动优化器 (optimizer(s)) 和调度器 (schedulers)
将优化器移至 c