Common Callbacks in TensorFlow
TensorBoard
import os
import datetime
import tensorflow as tf

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# Log each run to a timestamped directory so runs don't overwrite each other
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)

model.fit(train_batches,
          epochs=10,
          validation_data=validation_batches,
          callbacks=[tensorboard_callback])
TensorBoard lets us visualize whatever we log about the model (loss and accuracy curves, the model graph, histograms, and so on) in the browser; exploring what it can display is well worth the time.
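Once training has written some logs, the dashboard can be started from a terminal, pointed at the log directory used above:

tensorboard --logdir logs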
Checkpoint
from tensorflow.keras.callbacks import ModelCheckpoint

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          verbose=2,
          # The filename template is filled in each epoch, e.g. weights.03-0.45.h5
          callbacks=[ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.h5',
                                     verbose=1)])
With ModelCheckpoint we can save the model according to a chosen policy, for example every epoch (as above), at a fixed frequency via save_freq, or only when a monitored metric improves.
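A common variant is to keep only the best model seen so far; a minimal sketch using the standard monitor and save_best_only arguments:

# Overwrite best_model.h5 only when val_loss improves on its previous best
best_checkpoint = ModelCheckpoint('best_model.h5',
                                  monitor='val_loss',
                                  save_best_only=True,
                                  verbose=1)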
EarlyStopping
from tensorflow.keras.callbacks import EarlyStopping

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[EarlyStopping(
              monitor='val_loss',          # watch the validation loss
              mode='min',                  # lower is better
              patience=3,                  # stop after 3 epochs without improvement
              min_delta=0.05,              # changes smaller than this don't count
              baseline=0.8,                # val_loss must improve past this value
              restore_best_weights=True,   # roll back to the best epoch's weights
              verbose=1)])
When the validation loss starts to rise, the model is beginning to overfit and its generalization is degrading, so we can stop training early; with restore_best_weights=True the weights from the best epoch are restored.
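The same callback works for metrics where higher is better; a sketch monitoring validation accuracy instead (note mode='max'):

EarlyStopping(monitor='val_accuracy', mode='max', patience=3, verbose=1)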
CSVLogger
from tensorflow.keras.callbacks import CSVLogger

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# One row per epoch: epoch, loss, accuracy, val_loss, val_accuracy
csv_file = 'training.csv'
model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          callbacks=[CSVLogger(csv_file)])
CSVLogger writes the per-epoch training metrics out as a CSV file.
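The resulting file is easy to inspect or plot afterwards; a minimal sketch using pandas (assuming it is installed):

import pandas as pd

history = pd.read_csv('training.csv')
print(history[['epoch', 'loss', 'val_loss']])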
LearningRateScheduler
import math
from tensorflow.keras.callbacks import LearningRateScheduler, TensorBoard

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# Step decay: halve the learning rate after every epoch
def step_decay(epoch):
    initial_lr = 0.01
    drop = 0.5
    epochs_drop = 1
    lr = initial_lr * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lr

model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          callbacks=[LearningRateScheduler(step_decay, verbose=1),
                     TensorBoard(log_dir='./log_dir')])
Adjusting the learning rate on a schedule during training lets us start with large steps and anneal them, which often helps the model converge faster.
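With the constants above (epochs_drop=1), the schedule halves the rate every epoch; printing it out makes the decay concrete:

for epoch in range(5):
    print(epoch, step_decay(epoch))
# 0 0.005
# 1 0.0025
# 2 0.00125
# 3 0.000625
# 4 0.0003125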
ReduceLROnPlateau
from tensorflow.keras.callbacks import ReduceLROnPlateau, TensorBoard

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          callbacks=[ReduceLROnPlateau(monitor='val_loss',
                                       factor=0.2,    # new_lr = lr * factor
                                       patience=1,    # wait 1 stalled epoch
                                       min_lr=0.001,  # never go below this
                                       verbose=1),
                     TensorBoard(log_dir='./log_dir')])
Similar to the schedule above, except ReduceLROnPlateau only lowers the learning rate when the monitored metric stops improving: for example, with factor=0.2 an lr of 0.01 drops to 0.002 on the first plateau, and min_lr caps how far it can fall.
Defining a Custom Callback
We can subclass the Callback base class to define callbacks of our own.
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class MyCallback(Callback):
    def __init__(self, loss_threshold=0.01):
        super(MyCallback, self).__init__()
        self.loss_threshold = loss_threshold

    def on_train_begin(self, logs=None):
        print("training begin")

    def on_epoch_end(self, epoch, logs=None):
        # Keras reports the training loss under the key 'loss'
        logs = logs or {}
        if logs.get('loss', float('inf')) < self.loss_threshold:
            self.model.stop_training = True
            print('loss is low enough, stopping training')
The Callback base class defines many hook methods that we can override, such as on_train_begin and on_epoch_end.
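A quick usage sketch, reusing the same hypothetical build_model and data pipeline from the earlier examples:

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          callbacks=[MyCallback(loss_threshold=0.01)])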