Tensorboard 中加入matplotlib.pyplot输出的图

查了不少网站,在这介绍一段在使用tensorflow.keras进行神经网络训练的时候,通过编写一个tf.python.keras.callbacks.Callback的子类,实现在每个或几个epoch结束后,在Tensorboard的Image中显示matplotlib.pyplot生成的图片的方法。

import io
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.summary import summary as tf_summary
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np

class MyCallbacks(Callback):
    def __init__(self, logdir, period, val_data):
        super(MyCallbacks, self).__init__()
        self.logdir = logdir
        self.period = period
        self.last_rcd = 0
        self.writer = tf_summary.FileWriter(self.logdir)
        self.validation_data = val_data

    def gen_plot(self, y_predict):
        real_part = np.reshape(y_predict, [-1]) # vectorize y_predict
        imag_part = np.reshape(y_predict,[-1])
        plt.figure()
        plt.scatter(real_part, imag_part)
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        return buf


    def on_epoch_end(self, epoch, logs=None):
        self.last_rcd = self.last_rcd + 1
        if self.last_rcd >= self.period:
            self.last_rcd = 0
            y_predict = self.model.predict(self.validation_data, steps = 32)

            # Prepare the plot
            plot_buf = self.gen_plot(y_predict)

            # Convert PNG buffer to TF image
            image = tf.image.decode_png(plot_buf.getvalue(), channels=4)

            # Add the batch dimension
            image = tf.expand_dims(image, 0)

            # Add image summary           
            with tf.Session() as sess:
                # Run
                summary_op = tf.summary.image("plot", image)
                summary = sess.run(summary_op)
                # Write summary
                
                self.writer.add_summary(summary)
          
    def on_train_end(self, logs=None):
        self.writer.close()

在tf.keras下训练的时候如下调用:

callbacks = [
        # Write TensorBoard logs to `./logs` directory
        tf.keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True,write_grads=True, write_images = False),
        # Create checkpoint callback
        tf.keras.callbacks.ModelCheckpoint(checkpoint_path, verbose=1, save_weights_only=True, period=50),
        ConstellationCallbacks(logdir = log_dir, period = 10, val_data = my_val_data)
    ]

model.fit(training_data, training_label, epochs=2000, batch_size=256, shuffle = True,
              callbacks=callbacks,validation_data=(my_val_data, my_val_label))

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值