Callback--ModelCheckpoint

ModelCheckpoint

参数详解

参数名称含义默认值
dirpathckpt文件保存路径None(使用Trainer的default_root_dirweights_save_path,如果Trainer使用了logger,那么path将会包含logger的名字和版本)
filenameckpt文件名(支持自动填充)None({epoch}-{step})
monitor要监视的指标(log()记录的指标)None(保存最后一次epoch训练的结果)
save_last是否保存最后一次epoch训练的结果(last.ckpt)None(False)
save_top_k保存前k个最佳模型,k=-1的保存所有模型,k=0将不会保存模型,文件名后面会追加版本号,从v1开始1
save_weights_only仅保存模型权重False
mode监视指标的最大值还是最小值.对于loss应使用min,对于accuracy应使用max‘min’
auto_insert_metric_name是否自动向文件名中插入monitior的值True

使用注意

  • 如果设置auto_insert_metric_nameFalse,对于包含**/**的指标名,将会创建额外的文件夹
  • checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"可以改变save_last的默认文件名

示范

class MyModel(pl.LightningModule):
    def __init__(self, in_dim, out_dim, lr=0.00005):
        super().__init__()
        self.net = nn.Linear(in_dim, out_dim)
        self.loss_fn = nn.MSELoss()
        self.lr = lr

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self(X)
        loss = self.loss_fn(y_hat, y)
        acc = (y_hat.argmax(1) == y).type(torch.float).sum()
        self.log('Train_accuracy', acc,
                 on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self(X)
        loss = self.loss_fn(y_hat, y)
        self.log('Val_loss', loss,
                 on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return optim.SGD(self.net.parameters(), self.lr)

# 监视训练准确率,注意这里性能指标名字要与self.log中保持一致
ckpt_callback_train_acc = ModelCheckpoint(
    monitor='Train_accuracy', dirpath='my/path',
    filename='epoch{epoch:02d}-val_acc{Val_accuracy:.2f}',
    auto_insert_metric_name=False,
    save_last=True, save_weights_only=True, mode='max')
# 监视验证损失,注意这里性能指标名字要与self.log中保持一致
ckpt_callback_val_loss = ModelCheckpoint(
    monitor='Val_loss', dirpath='my/path', mode='min')

# 使用一个callback,可以不写成列表形式
trainer = pl.Trainer(callbacks=ckpt_callback_train_acc)
# 使用多个callbacks,传递callback列表
trainer = pl.Trainer(callbacks=[ckpt_callback_train_acc, ckpt_callback_val_loss])
# ....训练

self.log()用法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值