前言
太坑了。。查了很多资料,终于解决了。
方法
首先,根据官方API的参数列表得知,在tf.keras.callbacks.TensorBoard
中修改update_freq
参数为batch
或一个整数
update_freq='batch'
# update_freq=10
# 如果改为使用一个整数N的话,过N批次后更新一次数据,
# 这样可以避免由于更新过于频繁而降低网络训练速度
然后,根据这则帖子,由于TensorFlow 2.3
做了一个优化,导致上面的方法在这里不管用。
解决方法是,除了TensorBoard
之类的callback以外,再添加一个LambdaCallback
,具体代码如下:
def batchOutput(batch, logs):
tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
return batch
batch_log_callback = callbacks.LambdaCallback(
on_batch_end=batchOutput)
于是终于成功
示例代码
改完后,我的训练部分的完整代码是这样的:
def train_model(save:bool=True):
# load and compile model
model = create_model()
model.compile(
loss='mean_squared_error',
optimizer='adam',
metrics=['accuracy'])
# prepare tensorflow dashboard
logdir = os.path.join(
'logs',
datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = callbacks.TensorBoard(
logdir,
histogram_freq=1,
write_images=True,
update_freq=10, # 查看批次级别的数据变化,需要结合LambdaCallback
embeddings_freq=1,
profile_batch=1)
# 实现查看批次级别数据变化
def batchOutput(batch, logs):
tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
return batch
batch_log_callback = callbacks.LambdaCallback(
on_batch_end=batchOutput)
# prepare early stop
early_stop_callback = callbacks.EarlyStopping(
monitor='val_loss',
patience=0,
restore_best_weights=True)
# train model
epochs_num = 4
model.fit(x=X,
y=X,
epochs=epochs_num,
batch_size=64,
validation_data=(X_eval, X_eval),
verbose=1, # 0:silent, 1:progress bar, 2:one line per epoch
callbacks=[tensorboard_callback,
batch_log_callback,
early_stop_callback])
# save model
if save:
MODEL_FOLDER = '/content/drive/MyDrive/A-Million-Headlines/pretrained'
model_name = 'AutoEncoder-model-{}-epochs-{}.h5'.format(epochs_num, int(time.time()))
joblib.dump(model, os.path.join(MODEL_FOLDER, model_name))
return model