ModelCheckpoint和自定义回调函数(on_epoch_end)

ModelCheckpoint和自定义回调函数(on_epoch_end)区别

根据keras中文文档ModelCheckpoint的作用是:在每个训练期之后保存模型。

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。

例如:如果 filepath 是 weights.{epoch:02d}-{val_loss:.2f}.hdf5, 那么模型被保存的的文件名就会有训练轮数和验证损失。

参数

filepath: 字符串,保存模型的路径。
monitor: 被监测的数据。
verbose: 详细信息模式,0 或者 1 。
save_best_only: 如果 save_best_only=True, 被监测数据的最佳模型就不会被覆盖。
mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
period: 每个检查点之间的间隔(训练轮数)。
(来自keras中文文档)传送门

可能会有人疑惑那下面这段代码是做什么的

def get_lowest_loss():
    import re
    pattern = 'model.(?P<epoch>\d+)-(?P<val_acc>[0-9]*\.?[0-9]*).hdf5'
    p = re.compile(pattern)
    loss = [float(p.match(f).groups()[1]) for f in os.listdir('models/') if p.match(f)]
    if len(loss) == 0:
        import sys
        return sys.float_info.min
    else:
        return np.min(loss)
        
 class MyCbk(keras.callbacks.Callback):
       def __init__(self, model):
           keras.callbacks.Callback.__init__(self)
           self.model_to_save = model
#这段代码就是自定义回调函数的ModelCheckpoint的功能,on_epoch_end继承自keras.callbacks.Callback类
#的方法,即每次epoch结束时调用里面的代码   
       def on_epoch_end(self, epoch, logs=None):
            fmt = checkpoint_models_path + 'model.%02d-%.4f.hdf5'
            highest_acc = get_lowest_loss()
            if float(logs['val_loss']) > highest_acc:
                self.model_to_save.save(fmt % (epoch, logs['val_loss']))

上面代码的功能可以用下面这句话替代,两者的功能都是在一次epoch结束时保存val_loss最小的模型参数。

    model_names = checkpoint_models_path + 'model.{epoch:02d}-{val_loss:.4f}.hdf5'
    model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
在 Keras 中使用 ModelCheckpoint 回调函数可以在每个 epoch 结束时保存模型的权重。下面是一个简单的示例,展示了如何定义 ModelCheckpoint 回调函数: ```python from keras.callbacks import ModelCheckpoint # 定义 ModelCheckpoint 回调函数 checkpoint = ModelCheckpoint(filepath='model_weights.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min') # 在模型训练时使用 ModelCheckpoint 回调函数 model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, batch_size=32, callbacks=[checkpoint]) ``` 在上面的示例中,我们首先从 Keras 中导入 ModelCheckpoint 回调函数。然后,我们定义了一个名为 `checkpoint` 的 ModelCheckpoint 对象。这个对象有几个参数: - `filepath`:保存模型权重的路径。可以是绝对路径或相对路径。 - `monitor`:监视的指标。在这个例子中,我们使用验证集上的损失函数作为监视指标。 - `verbose`:日志输出级别。在这个例子中,我们将它设置为 1,这样每次保存模型权重时,都会输出一条消息。 - `save_best_only`:是否只保存最佳模型权重。在这个例子中,我们将它设置为 True,这样只有当监视指标有所改善时,才会保存模型权重。 - `mode`:监视指标的模式。在这个例子中,我们将它设置为 'min',表示我们希望监视的指标越小越好。 最后,我们将 ModelCheckpoint 对象传递给模型的 `fit` 方法的 `callbacks` 参数中,这样在模型训练时,每当一个 epoch 结束时,ModelCheckpoint 回调函数就会自动保存模型的权重到指定的文件中。 需要注意的是,如果您要使用 ModelCheckpoint 回调函数,请确保您的代码中已经定义了一个 Keras 模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值