查阅了不少资料后，在这里介绍一种方法：在使用 tensorflow.keras 训练神经网络时，通过编写 tensorflow.python.keras.callbacks.Callback 的子类，在每个（或每隔几个）epoch 结束后，将 matplotlib.pyplot 生成的图片显示在 TensorBoard 的 Images 面板中。
import io
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.summary import summary as tf_summary
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np
class MyCallbacks(Callback):
    """Keras callback that logs a scatter plot of model predictions to TensorBoard.

    Every ``period`` epochs the model is run on ``val_data``, the predictions
    are drawn as points in the complex plane with matplotlib, and the PNG is
    written as an image summary to ``logdir``.
    """

    def __init__(self, logdir, period, val_data, predict_steps=32):
        """Create the callback.

        Args:
            logdir: directory the summary FileWriter writes event files to.
            period: write a plot once every ``period`` epochs.
            val_data: input passed to ``self.model.predict``.
            predict_steps: ``steps`` argument for ``model.predict``
                (defaults to 32, the value previously hard-coded).
        """
        super(MyCallbacks, self).__init__()
        self.logdir = logdir
        self.period = period
        self.predict_steps = predict_steps
        # Epochs elapsed since the last plot was written.
        self.last_rcd = 0
        self.writer = tf_summary.FileWriter(self.logdir)
        self.validation_data = val_data
        # Private copy used for prediction: some Keras versions overwrite a
        # callback's ``validation_data`` attribute during ``fit()``.
        self._val_data = val_data

    def gen_plot(self, y_predict):
        """Render predictions as a scatter plot in the complex plane.

        Args:
            y_predict: model output. Complex arrays are split with
                ``np.real`` / ``np.imag``. Real arrays whose last axis has
                size 2 are read as ``[..., 0]`` = real part and
                ``[..., 1]`` = imaginary part (assumed layout -- TODO confirm
                against the model's output). Any other real array is plotted
                on the real axis (imaginary part 0).

        Returns:
            io.BytesIO positioned at offset 0, holding the PNG-encoded figure.
        """
        y = np.asarray(y_predict)
        if np.iscomplexobj(y):
            real_part = np.real(y).reshape(-1)
            imag_part = np.imag(y).reshape(-1)
        elif y.ndim >= 1 and y.shape[-1] == 2:
            real_part = y[..., 0].reshape(-1)
            imag_part = y[..., 1].reshape(-1)
        else:
            real_part = y.reshape(-1)
            imag_part = np.zeros_like(real_part)

        fig = plt.figure()
        try:
            ax = fig.add_subplot(1, 1, 1)
            ax.scatter(real_part, imag_part)
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # pyplot keeps every figure alive until closed; without this one
            # figure would leak per logged epoch.
            plt.close(fig)
        buf.seek(0)
        return buf

    def on_epoch_end(self, epoch, logs=None):
        """Every ``period`` epochs, predict on the stored data and log the plot.

        Args:
            epoch: index of the epoch that just ended; used as the summary step.
            logs: unused; part of the Keras callback signature.
        """
        import struct

        self.last_rcd += 1
        if self.last_rcd < self.period:
            return
        self.last_rcd = 0

        y_predict = self.model.predict(self._val_data, steps=self.predict_steps)
        png = self.gen_plot(y_predict).getvalue()

        # Width, height and colour type come from the PNG IHDR chunk, which
        # follows the 8-byte signature, 4-byte chunk length and b'IHDR'.
        width, height = struct.unpack('>II', png[16:24])
        color_type = struct.unpack('>B', png[25:26])[0]
        # PNG colour type -> Summary.Image colorspace
        # (gray=1, gray+alpha=2, RGB=3, RGBA=4).
        colorspace = {0: 1, 4: 2, 2: 3, 6: 4}.get(color_type, 4)

        # Build the summary protobuf directly from the PNG bytes. Creating
        # decode/summary ops inside a fresh tf.Session on every call added new
        # ops to the default graph each time, and the uniquified op scope
        # tended to give each image a different tag.
        image = tf.Summary.Image(height=height, width=width,
                                 colorspace=colorspace,
                                 encoded_image_string=png)
        summary = tf.Summary(value=[tf.Summary.Value(tag='plot/image',
                                                     image=image)])
        # Record the epoch as the step so images do not all pile up at step 0.
        self.writer.add_summary(summary, global_step=epoch)
        self.writer.flush()

    def on_train_end(self, logs=None):
        """Close the summary writer when training finishes."""
        self.writer.close()
使用 tf.keras 训练时，按如下方式调用：
callbacks = [
    # Write TensorBoard logs to the `log_dir` directory.
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True,
                                   write_grads=True, write_images=False),
    # Save model weights (weights only) every 50 epochs.
    tf.keras.callbacks.ModelCheckpoint(checkpoint_path, verbose=1,
                                       save_weights_only=True, period=50),
    # Plot predictions on `my_val_data` to TensorBoard every 10 epochs.
    # The callback class defined in this file is `MyCallbacks`;
    # `ConstellationCallbacks` is not defined and raised NameError.
    MyCallbacks(logdir=log_dir, period=10, val_data=my_val_data),
]
model.fit(training_data, training_label, epochs=2000, batch_size=256,
          shuffle=True, callbacks=callbacks,
          validation_data=(my_val_data, my_val_label))