import tensorflow as tf
from tensorflow.keras import layers
# 定义一个myCallback类,继承了tensorflow中自带的Callback类
class myCallback(tf.keras.callbacks.Callback):
    """Keras callback that stops training once training accuracy exceeds a threshold.

    At the end of every epoch the ``'accuracy'`` entry of the epoch logs is
    compared against ``threshold``. When it is strictly greater, a message is
    printed and ``self.model.stop_training`` is set to True so that ``fit()``
    stops after the current epoch.

    Args:
        threshold: Accuracy value that must be exceeded to stop training.
            Defaults to 0.96, the value this callback originally hard-coded.
    """

    def __init__(self, threshold=0.96):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Stop training if this epoch's accuracy exceeds ``self.threshold``.

        Args:
            epoch: Index of the epoch that just finished (not used here).
            logs: Dict of metric values for the epoch. May be None, and may
                lack an ``'accuracy'`` key (e.g. if the model was compiled
                without that metric); in either case nothing happens.
        """
        # logs defaults to None, so guard before calling .get().
        accuracy = (logs or {}).get('accuracy')
        # Skip when the metric is missing: `None > float` raises TypeError.
        if accuracy is not None and accuracy > self.threshold:
            # With the default threshold this prints exactly "96%".
            print(f"\n 已经到达 {self.threshold * 100:g}% 的训练精度!")
            self.model.stop_training = True
# Instantiate the early-stopping callback; it is passed to model.fit() below.
callback = myCallback()

# Load the MNIST train/test split that ships with Keras.
mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# Rescale raw pixel intensities by 1/255 (presumably uint8 values in
# [0, 255] -> floats in [0, 1]; confirm against the dataset docs).
training_images, test_images = training_images / 255.0, test_images / 255.0

# Small fully connected classifier: flatten each input, one hidden ReLU
# layer with 128 units, and a 10-unit softmax output layer.
model = tf.keras.models.Sequential()
model.add(layers.Flatten())
model.add(layers.Dense(128, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))

# Integer class labels pair with the sparse variant of cross-entropy.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for at most 10 epochs; the callback may request an earlier stop.
model.fit(
    training_images,
    training_labels,
    epochs=10,
    callbacks=[callback],
)

# Evaluate the trained model on the held-out test split.
model.evaluate(test_images, test_labels)
# tensorflow2.x 使用callback方法停止模型训练(mnist手写数字)
# 最新推荐文章于 2024-04-23 01:33:40 发布