Callback
1. 先看一下callback类源码
class Callback(object):
def __init__(self):
self.validation_data = None
self.model = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
def on_epoch_begin(self, epoch, logs=None):
pass
def on_epoch_end(self, epoch, logs=None):
pass
def on_batch_begin(self, batch, logs=None):
pass
def on_batch_end(self, batch, logs=None):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
我们在调用是可以这样写(保存最优模型)
# 回调类1
class Evaluator1(keras.callbacks.Callback):
def __init__(self):
self.best_val_acc = 0.
# 每迭代一次,调用一次
def on_epoch_end(self, epoch, logs=None):
val_acc = evaluate(valid_generator)
if val_acc > self.best_val_acc:
self.best_val_acc = val_acc
model.save_weights(r'data/best_model.weights')
test_acc = evaluate(test_generator)
print(
u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
(val_acc, self.best_val_acc, test_acc)
)
2. 自带的回调方法
2.1 ModelCheckpoint
在每个训练期之后保存模型。相当于在on_epoch_end()中保存了模型
参数:
2.2 EarlyStopping
当被监测的数量不再提升,则停止训练。
参数:
2.3 TensorBoard
学习中…