pytorch lightning--ModelCheckpoint

本笔记主要以pytorch lightning中的ModelCheckpoint接口解析pytorch lightning中模型的保存方式

ModelCheckpoint

该类通过监控设置的metric定期保存模型,LightningModule 中使用 log() 或 log_dict() 记录的每个metric都是监控对象的候选者;更多的信息可以进入此链接浏览。训练完成后,在日志中使用 best_model_path 检索最佳checkpoint的路径,使用 best_model_score 检索其分数

pytorch_lightning.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode=‘min’, auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None)

  • dirpath(Union[str, Path, None])–保存模型文件的路径;默认情况下,dirpath 为 None 并将在运行时设置为 Trainer 的 default_root_dir 或 weights_save_path 参数指定的位置,如果 Trainer 使用logger,则路径还将包含logger名称和版本
# custom path,自定义路径
# saves a file like: my/path/epoch=0-step=10.ckpt,文件名会自动由epoch和step构成
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
  • filename(Optinal[str])–checkpoint文件名;可以包含要自动填充的命名格式选项;默认情况下,文件名是 None 并将设置为 ‘{epoch}-{step}’
# save any arbitrary metrics like `val_loss`, etc. in name,将感兴趣的metrics保存在文件名中
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
checkpoint_callback = ModelCheckpoint(
...     dirpath='my/path',
...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
  • monitor(Optional[str])–被监控的值;默认情况下,它是 None ,它只为最后一个 epoch 保存一个检查点
  • verbose(bool)–详细模式;默认为False
  • save_last(Optional[bool])–为True时,无论训练过程中是否有checkpoint文件保存,在训练结束后都会保存一个名为last.ckpt的checkpoint文件。这允许以确定的方式访问最新的checkpoint文件;默认为None
  • save_top_k(int)–如果save_top_k=k,基于设置的检测对象,其最好的k个模型会被保存;如果save_top_k=0,没有模型会被保存;如果save_top_k=-1,所有模型都会保存。注意,监视器是每every_n_epochs个epoch会被检查一次。如果save_top_k>2并且在一次epoch多次调用回调,保存文件的名称将附加以 v1 开头的版本计数
  • mode(str)–{min,max}其中之一。如果save_top_k> != 0,覆盖当前保存文件的决定是基于监控值的最大化或最小化做出的;如对于val_acc,应该越大越好,要设置max,而对于val_loss,应该越小越好,要设置为min
  • auto_insert_metric_name(bool)–为True时,保存的checkpoint文件名会包含metric名称。例如,带有 epoch 1 和 acc 1.12 的 filename=‘checkpoint_{epoch:02d}-{acc:02.0f} 将解析为 checkpoint_epoch=01-acc=01.ckpt。当metric名称中含有’/'时,设置为False较好,否则会导致生成额外的文件夹,如filename=‘epoch={epoch}-step={step}-val_acc={val/acc:.2f}’, auto_insert_metric_name=False
  • every_n_train_steps(Optional[int])–checkpoint文件两次保存之间的step数。如果every_n_train_steps = None或every_n_train_steps = 0,在训练过程中不保存checkpoint文件。该值必须是None或非负值,并且 其必须与train_time_interval 和 every_n_epochs 互斥
  • train_time_interval(Optional[timedelta])–以指定的时间间隔监控检checkpoints;出于实际目的,不能小于处理单个训练批次所需的时间;不能保证在指定的确切时间执行,但应该接近。并且 其必须与every_n_train_steps 和 every_n_epochs 必须互斥
  • every_n_epochs(Optional[int])–checkpoints文件两次保存之间的epoch数;必须是None或非负值;将其设置为every_n_epochs=0可以禁止save_top_k;此参数不影响 save_last=True 检查点的保存;如果every_n_epochs,every_n_train_steps和train_time_interval都为None,将在每个epoch结束时保存一个checkpoint文件(相当于every_n_epochs=1)如果every_n_epochs=None,且every_n_train_steps != None或train_time_interval != None时,在每个epoch结束时保存保存失效,(相当于every_n_epochs=0);every_n_epochs必须与every_n_train_steps 和 train_time_interval 互斥
  • save_on_train_epoch_end(Optional[bool])–是否在训练周期结束时运行检查点。如果是 False,则检查在验证结束时运行

简单案例代码

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path  # 直接获取最好的模型保存的路径

同时保存和恢复多个checkpoint的回调是支持的,可浏览官方文档学习使用

模型手动保存

model = Pytorch_Lightning_Model(args)
train.fit(model)
train.save_checkpoint(example.ckpt)
  • 36
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值