ModelCheckpoint
常用参数
参数名称 | 含义 | 默认值 |
---|
dirpath | ckpt文件保存路径 | None(使用Trainer的default_root_dir 或weights_save_path ,如果Trainer使用了logger ,那么path将会包含logger的名字和版本) |
filename | ckpt文件名(支持自动填充) | None ({epoch}-{step} ) |
monitor | 要监视的指标(log() 记录的指标) | None (保存最后一次epoch训练的结果) |
save_last | 是否保存最后一次epoch训练的结果(last.ckpt ) | None(False) |
save_top_k | 保存前k个最佳模型,k=-1的保存所有模型,k=0将不会保存模型,文件名后面会追加版本号,从v1开始 | 1 |
save_weights_only | 仅保存模型权重 | False |
mode | 监视指标的最大值还是最小值。对于loss应使用min ,对于accuracy应使用max | min |
auto_insert_metric_name | 是否自动向文件名中插入monitior 的值 | True |
save_on_train_epoch_end | 是否在训练epoch结束是执行检查 | None |
常用属性
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
,设置save_last
文件名FILE_EXTENSION = ".ckpt"
,文件扩展名STARTING_VERSION = 1
,开始版本号
使用注意
- 如果设置
auto_insert_metric_name
为False
,对于包含/
的指标名,将会创建额外的文件夹 checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
可以改变save_last
的默认文件名
示范
class MyModel(pl.LightningModule):
def __init__(self, in_dim, out_dim, lr=0.00005):
super().__init__()
self.net = nn.Linear(in_dim, out_dim)
self.loss_fn = nn.MSELoss()
self.lr = lr
def forward(self, x):
return self.net(x)
def training_step(self, batch, batch_idx):
X, y = batch
y_hat = self(X)
loss = self.loss_fn(y_hat, y)
acc = (y_hat.argmax(1) == y).type(torch.float).sum()
self.log('Train_accuracy', acc,
on_epoch=True, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
X, y = batch
y_hat = self(X)
loss = self.loss_fn(y_hat, y)
self.log('Val_loss', loss,
on_epoch=True, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
return optim.SGD(self.net.parameters(), self.lr)
ckpt_callback_train_acc = ModelCheckpoint(
monitor='Train_accuracy', dirpath='my/path',
filename='epoch{epoch:02d}-val_acc{Val_accuracy:.2f}',
auto_insert_metric_name=False,
save_last=True, save_weights_only=True, mode='max')
ckpt_callback_val_loss = ModelCheckpoint(
monitor='Val_loss', dirpath='my/path', mode='min')
trainer = pl.Trainer(callbacks=ckpt_callback_train_acc)
trainer = pl.Trainer(callbacks=[ckpt_callback_train_acc, ckpt_callback_val_loss])
self.log()用法