'''
参考stackoverflow:
https://stackoverflow.com/questions/49127214/keras-how-to-output-learning-rate-onto-tensorboard
'''
############# 添加部分 ###############
class LearningRateLogger(tf.keras.callbacks.Callback):
    """Add the optimizer's current learning rate to the per-epoch logs.

    The value is stored under the ``learning_rate`` key of the ``logs`` dict,
    so callbacks that run after this one (e.g. ``TensorBoard``) can pick it up.
    Place it before ``TensorBoard`` in the callbacks list.

    Adapted from:
    https://stackoverflow.com/questions/49127214/keras-how-to-output-learning-rate-onto-tensorboard
    """

    def __init__(self):
        super().__init__()
        # Opt in to receiving the raw (possibly tensor-valued) logs dict,
        # as in the original recipe.
        self._supports_tf_logs = True

    def on_epoch_end(self, epoch, logs=None):
        # Nothing to annotate, or some other component already logged the rate.
        if logs is None or "learning_rate" in logs:
            return
        optimizer = self.model.optimizer
        # `lr` is a deprecated alias; `learning_rate` is the documented attribute.
        lr = optimizer.learning_rate
        # When a LearningRateSchedule is used, the attribute is the schedule
        # object rather than a number; evaluate it at the current step.
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(optimizer.iterations)
        # Store a plain float snapshot instead of the live optimizer variable,
        # so a later change to the rate (e.g. by ReduceLROnPlateau) cannot
        # alter what is recorded for this epoch.
        logs["learning_rate"] = float(tf.keras.backend.get_value(lr))
# Callbacks used when training the model below. Keras invokes them in list
# order, so the relative position of the entries matters.
callbacks = [
    # Save the full model after every epoch. `save_freq='epoch'` replaces the
    # deprecated `period=1` argument with the same behaviour. `monitor` is only
    # consulted when save_best_only=True, so 'val_accuracy' has no effect here.
    tf.keras.callbacks.ModelCheckpoint(
        "densenet_crnn_qunt_{epoch}.h5",
        monitor='val_accuracy',
        save_best_only=False,
        verbose=2,
        save_weights_only=False,
        save_freq='epoch'),
    # Stop after 5 epochs without val_loss improvement and restore the best
    # weights seen so far.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                     patience=5,
                                     mode='auto',
                                     restore_best_weights=True),
    # Writes the learning rate into the epoch logs. It must come before
    # TensorBoard so the value is present when TensorBoard writes its scalars,
    # and before ReduceLROnPlateau so it records the rate used during this epoch.
    LearningRateLogger(),
    tf.keras.callbacks.TensorBoard(log_dir='logs/', histogram_freq=0),
    # Patience must be smaller than EarlyStopping's (5). With equal patience,
    # both callbacks fire on the same epoch, training stops, and the reduced
    # learning rate is never used.
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                         factor=0.2,
                                         patience=3,
                                         min_lr=1e-8,
                                         verbose=2),
]
# Train the model; `val_ds` provides the `val_*` metrics that the
# checkpoint / early-stopping / LR-reduction callbacks monitor.
quant_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=args.epochs,
    callbacks=callbacks,
)
只需在 callbacks 列表中加入 LearningRateLogger() 即可（需放在 TensorBoard 回调之前，这样 TensorBoard 写日志时 logs 中已包含学习率）：
然后在命令行中运行：
tensorboard --logdir=./ --host=127.0.0.1
之后即可在 TensorBoard 的 Scalars 面板中看到学习率随 epoch 的变化曲线：