TensorFlow的中阶API主要包括:
数据管道(tf.data)
特征列(tf.feature_column)
激活函数(tf.nn)
模型层(tf.keras.layers)
损失函数(tf.keras.losses)
评估指标(tf.keras.metrics)
优化器(tf.keras.optimizers)
回调函数(tf.keras.callbacks)
如果把模型比作一个房子,那么中阶API就是【模型之墙】。
本篇我们介绍回调函数。
一,回调函数概述
tf.keras的回调函数实际上是一个类,一般是在model.fit时作为参数指定,用于控制在训练过程开始或者在训练过程结束,在每个epoch训练开始或者训练结束,在每个batch训练开始或者训练结束时执行一些操作,例如收集一些日志信息,改变学习率等超参数,提前终止训练过程等等。
同样地,针对model.evaluate或者model.predict也可以指定callbacks参数,用于控制在评估或预测开始或者结束时,在每个batch开始或者结束时执行一些操作,但这种用法相对少见。
大部分时候,keras.callbacks子模块中定义的回调函数类已经足够使用了,如果有特定的需要,我们也可以通过对keras.callbacks.Callbacks实施子类化构造自定义的回调函数。
所有回调函数都继承至 keras.callbacks.Callbacks基类,拥有params和model这两个属性。
其中params 是一个dict,记录了 training parameters (eg. verbosit