11.pytorch lightning之ModelCheckpoint

ModelCheckpoint

callback回调就是在合适的时候调用相应的函数,比如在训练开始前,训练结束后,每个epoch前后等等。training_step(), validation_step(), test_step()等函数都是在合适的时候被调用的。
ModelCheckpoint是内置的一个callback,定义了checkpoint保存的方式,其内部在对应的step和epoch过程中实现了保存模型的逻辑,一个例子如下:

from lightning.pytorch.callbacks import ModelCheckpoint


# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,# 保存最好的10个checipoints
    monitor="val_loss",# 比较的指标为val_loss
    mode="min",# 越小越好
    dirpath="my/path/",# 保存路径
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",# 保存文件名格式。
)

需要注意的是,montior指定的指标必须在LightningModule类中指定。

from lightning.pytorch.callbacks import ModelCheckpoint


class LitAutoEncoder(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)

        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)

        # 2. log val_loss
        # 记录loss,记录了才能比较
        self.log("val_loss", loss)


# 3. Init ModelCheckpoint callback, monitoring "val_loss"
# 创建一个ModelCheckpoint实例
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# 4. Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])

另外也可以使用YAML文件配置。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值