import tensorflow as tf
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if(logs.get('loss')<0.4):
print("\nLoss is low so cancelling training!")
self.model.stop_training = True
这是让python回调的代码。它被实现为一个单独的类,但可以和其他代码串联起来,它不需要在一个单独的文件中。
在这里,定义on_epoch_end()函数,在迭代结束后,调用回调。它也发送一个包含很多关于训练当前状态的信息—日志对象,也就是
if(logs.get(‘loss’)<0.4):
self.model.stop_training = True
当前损失率loss在日志中可用,所以我们可以查询一定数量。在这里,loss低于0.4就取消训练。
现在有了回调函数,
callbacks = myCallback()#实例化刚刚创建的类,回调类本身
mnist = tf.keras.datasets.fashion_mnist
(training_images, training_labels),(test_images, test_labels) = mnist.load_data()
training_images = training_images / 255.0
test_images = test_images / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy')
model.fit(training_images, training_labels, epochs=5, callbacks = [callbacks])#最后一个参数是使用回调参数作为训练循环的一部分,并传递它的类实例
nodel.fit()函数是让其训练的。
运行之后,发现:我们设置了5个周期并且在两个训练周期后就结束了训练,因为损失已经低于0.4了。