Tensorflow2自定义回调

在进行数据拟合的时候,需要输出信息。自定义回调函数可以输出你想要的消息。

先创建网络结构

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

# Define the Keras model to add callbacks to
def get_model():
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))
  model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
  return model

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

接下来创建回调类

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

这个类需要先继承tf.keras.callbacks.Callback。然后重写函数on_train_batch_begin,on_train_batch_end,on_test_batch_begin,on_test_batch_end。on_train_batch_begin是对每次batch训练开始时输出数据,on_train_batch_end是指结束时;on_test_batch_begin是指测试开始时,on_test_batch_end是指结束。

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[MyCustomCallback()])

训练时的输出

Training: batch 0 begins at 02:40:32.065132
Training: batch 0 ends at 02:40:32.559227
Training: batch 1 begins at 02:40:32.559599
Training: batch 1 ends at 02:40:32.562399
Training: batch 2 begins at 02:40:32.562614
Training: batch 2 ends at 02:40:32.564616
Training: batch 3 begins at 02:40:32.564818
Training: batch 3 ends at 02:40:32.566912
Training: batch 4 begins at 02:40:32.567254
Training: batch 4 ends at 02:40:32.569301

看下测试时的输出:

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
          callbacks=[MyCustomCallback()])
Evaluating: batch 0 begins at 02:40:32.618486
Evaluating: batch 0 ends at 02:40:32.668369
Evaluating: batch 1 begins at 02:40:32.668636
Evaluating: batch 1 ends at 02:40:32.670351
Evaluating: batch 2 begins at 02:40:32.670740
Evaluating: batch 2 ends at 02:40:32.672338
Evaluating: batch 3 begins at 02:40:32.672556
Evaluating: batch 3 ends at 02:40:32.674355
Evaluating: batch 4 begins at 02:40:32.674709
Evaluating: batch 4 ends at 02:40:32.676414

除了这四种函数可以重写之外,还有

on_(train|test|predict)_begin(self, logs=None)
on_(train|test|predict)_end(self, logs=None)
on_(train|test|predict)_batch_begin(self, batch, logs=None)
on_(train|test|predict)_batch_end(self, batch, logs=None)

logs是个字典里面有各类信息,除此之外:

on_epoch_begin(self, epoch, logs=None)
on_epoch_end(self, epoch, logs=None)

例子如下:

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

  def on_train_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_test_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_epoch_end(self, epoch, logs=None):
    print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=3,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is   27.43.
For batch 1, loss is 1166.57.
For batch 2, loss is   26.95.
For batch 3, loss is    7.53.
For batch 4, loss is    9.72.
The average loss for epoch 0 is  247.64 and mean absolute error is    9.01.
For batch 0, loss is    7.49.
For batch 1, loss is    6.87.
For batch 2, loss is    4.39.
For batch 3, loss is    5.24.
For batch 4, loss is    4.75.
The average loss for epoch 1 is    5.75 and mean absolute error is    1.94.

参考文献:
https://tensorflow.google.cn/guide/keras/custom_callback

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值