Keras中如何利用回调函数Callback设置训练前先验证一遍验证集数据

【时间】2019.12.06

【题目】Keras中如何利用回调函数Callback设置训练前先验证一遍验证集数据

背景:

Keras中能够利用keras.callbacks.ModelCheckpoint()回调函数设置保存最佳权重的点,如:

  checkpoint = ModelCheckpoint(filepath=best_filepath,monitor="val_acc", verbose=1,
                              save_best_only=True,mode="max",save_weights_only=True) 

但是在第一个epoch时,监视指标为-np.inf (以val_acc为例),这样如果你加载已有weights后进行训练的第一个epoch一定会保存,这不符合初衷。一个解决办法是:训练前先验证一遍验证集数据。但是查看了keras,好像没有实现这个功能的函数。

最后查看keras.callbacks.ModelCheckpoint()的源代码,发现是通过监控checkpoint.best来决定是否保存,这样就可以这样实现:

1、使用Callback的自定义函数在第训练之前先对验证集验证一遍,获得监控指标,如val_acc

2、将监控指标赋值给checkpoint.best。

一、实现方法

使用Callback的自定义回调函数在第训练之前先对验证集验证一遍,获得监控指标,并赋值给checkpoint.best

自定义回调函数代码:

##EvaluateBeforeTrain
class EvaluateBeforeTrain(keras.callbacks.Callback):
    def __init__(self,checkpoint):
      super(Callback,self).__init__()
      self.checkpoint=checkpoint

    def on_epoch_begin(self,epoch,logs={}):
      #evaluate validation data
      if epoch==0:
        X=self.validation_data[0]
        Y=self.validation_data[1]
        result=self.model.evaluate(X,Y)
        loss=result[0]
        acc=result[1]
        self.checkpoint.best=acc
        print('first_val_acc is:',acc)

分析:

1.在自定义Callback回调函数时,使用self.validation_data可获得模型在fit()中传入的validation_data,使用self.model能够获得对应模型实例。

2、定义on_epoch_begin()函数可以让函数在每个epoch之前执行

PS:不使用on_train_begin()函数的原因是此函数获取不了self.validation_data。(不知原因)

3、传入的参数self.checkpoint=checkpoint是ModelCheckPoint的实例,目的是为了给checkpoint.best赋值

二、测试代码

from sklearn.preprocessing import LabelBinarizer
from keras.optimizers import  SGD
from keras.datasets import cifar10
from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import ModelCheckpoint,Callback

best_filepath=''
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
trainX = trainX.reshape((trainX.shape[0], 3072))
testX = testX.reshape((testX.shape[0], 3072))

# convert the labels from integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)
##DNN
model = Sequential()
model.add(Dense(1024, input_shape=(3072,), activation="relu"))
model.add(Dense(512, activation="relu"))
model.add(Dense(10, activation="softmax"))

if os.path.exists(best_filepath):
    model.load_weights(best_filepath)
    print("have load weight")

checkpoint = ModelCheckpoint(filepath=best_filepath,monitor="val_acc", verbose=1,
                              save_best_only=True,mode="max",save_weights_only=True) 
evaluateBeforeTrain=EvaluateBeforeTrain(checkpoint)
callbacks_list = [checkpoint,evaluateBeforeTrain]
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
  H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=64,callbacks=callbacks_list,verbose=1)#callbacks=callbacks_list,
  model.save_weights(final_save_path)

运行结果:

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值