ModelCheckpoint
callback回调就是在合适的时候调用相应的函数,比如在训练开始前,训练结束后,每个epoch前后等等。training_step(), validation_step(), test_step()等函数都是在合适的时候被调用的。
ModelCheckpoint是内置的一个callback,定义了checkpoint保存的方式,其内部在对应的step和epoch过程中实现了保存模型的逻辑,一个例子如下:
from lightning.pytorch.callbacks import ModelCheckpoint
# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = ModelCheckpoint(
save_top_k=10,# 保存最好的10个checipoints
monitor="val_loss",# 比较的指标为val_loss
mode="min",# 越小越好
dirpath="my/path/",# 保存路径
filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",# 保存文件名格式。
)
需要注意的是,montior指定的指标必须在LightningModule类中指定。
from lightning.pytorch.callbacks import ModelCheckpoint
class LitAutoEncoder(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
# 1. calculate loss
loss = F.cross_entropy(y_hat, y)
# 2. log val_loss
# 记录loss,记录了才能比较
self.log("val_loss", loss)
# 3. Init ModelCheckpoint callback, monitoring "val_loss"
# 创建一个ModelCheckpoint实例
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
# 4. Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
另外也可以使用YAML文件配置。