Pytorch-Lightning中模型保存与加载
保存
自动保存
from lightning.pytorch.callbacks import ModelCheckpoint
class LitAutoEncoder(LightningModule):
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        Logging the metric under the name "val_loss" makes it available
        to ModelCheckpoint(monitor="val_loss"), which then saves the best
        checkpoint according to this value.
        """
        x, y = batch
        # Forward pass through the wrapped network.
        # NOTE(review): assumes self.backbone is defined in __init__ (not shown here).
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        # This logged key is what the ModelCheckpoint callback monitors.
        self.log("val_loss", loss)
# Automatic saving: ModelCheckpoint tracks the logged "val_loss" metric
# and keeps the best checkpoint; attach it to the Trainer via callbacks.
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpoint_callback])
手动保存
# Manual saving: train the model, then explicitly write a checkpoint.
model = MyLightningModule(hparams)
trainer.fit(model)
# save_checkpoint persists weights, optimizer state and hyperparameters.
trainer.save_checkpoint("example.ckpt")
加载
使用load_from_checkpoint()
# Restore a trained model (weights + saved hyperparameters) from a checkpoint.
# Fixed class-name typo: "MyLightingModule" -> "MyLightningModule",
# matching the class the checkpoint was saved from.
model = MyLightningModule.load_from_checkpoint(PATH)
使用Trainer中恢复
# Resume a full training run (model weights, optimizer state, epoch, etc.)
# by passing ckpt_path to Trainer.fit.
model = LitModel()
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
超参数保存与覆盖
class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        """Record constructor arguments as hyperparameters and build the layer.

        save_hyperparameters() stores in_dim/out_dim inside the checkpoint,
        so load_from_checkpoint() can rebuild the model without the caller
        re-passing them (and individual values can be overridden at load time).
        """
        super().__init__()
        # Exposes the ctor args via self.hparams and persists them in checkpoints.
        self.save_hyperparameters()
        self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)
# Construct with concrete hyperparameters; they are saved into the checkpoint.
LitModel(in_dim=32, out_dim=10)
# Reload using the hyperparameters stored in the checkpoint.
model = LitModel.load_from_checkpoint(PATH)
# Or override selected saved hyperparameters at load time.
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)