一、该项目的演示代码如下所示
import tensorflow as tf
import os
import datetime
# Environment switches consumed by TensorFlow's runtime.
# NOTE(review): they are set after `import tensorflow`; this presumably still
# takes effect because device initialization is lazy (it happens on the first
# op) — confirm for your TF version.
os.environ.update({
    # Flag name suggests it re-enables XLA device registration — confirm.
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    # Let GPU memory grow on demand instead of reserving it all up front.
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})
def _preprocess(images, labels):
    """Turn raw MNIST arrays into model-ready tensors.

    Adds a trailing channel axis to ``images``, divides the pixel values by 255
    and casts them to float32, and casts ``labels`` to int64.

    Returns:
        ``(images, labels)`` as TensorFlow tensors.
    """
    images = tf.expand_dims(images, -1)
    images = tf.cast(images / 255, tf.float32)
    labels = tf.cast(labels, tf.int64)
    return images, labels


(train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()
train_image, train_label = _preprocess(train_image, train_label)
test_image, test_label = _preprocess(test_image, test_label)

# Shuffle *before* repeat: with a 60000-element buffer this yields a fresh
# permutation of the training set every epoch. Shuffling after repeat would let
# samples from the next epoch mix in before the current one is exhausted.
# The trailing repeat() makes the dataset infinite, which is why fit() is
# given an explicit steps_per_epoch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_image, train_label))
    .shuffle(60000)
    .repeat()
    .batch(128)
)
# Test data is only repeated (not shuffled) so fit() can draw a fixed number of
# validation batches each epoch.
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_label)).repeat().batch(128)
# Small convolutional classifier for 28x28x1 inputs, assembled layer by layer.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
# Reduce each feature map to its maximum value: one value per channel.
model.add(tf.keras.layers.GlobalMaxPooling2D())
# 10 softmax outputs -> per-class probabilities.
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Labels are integer class ids (cast to int64, not one-hot), hence the sparse
# cross-entropy. The 0.001 learning rate given here is overridden every epoch
# by the LearningRateScheduler passed to fit().
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy'],
)
# One timestamped directory per launch (logs/YYYYmmdd-HHMMSS) so that
# successive runs appear as separate runs in TensorBoard.
log_dir = os.path.join('logs', datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
# histogram_freq=1: compute weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
# Dedicated writer for the custom learning-rate scalar logged by lr_sche.
# Build the sub-path with os.path.join, as for log_dir above, rather than
# string concatenation.
file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'lr'))
# Make it the default writer so bare tf.summary.scalar(...) calls land here.
file_writer.set_as_default()
def lr_sche(epoch):
learning_rate = 0.2
if epoch > 5:
learning_rate = 0.02
if epoch > 10:
learning_rate = 0.01
if epoch > 20:
learning_rate = 0.005
tf.summary.scalar('learning_rate', data=learning_rate, step=epoch)
return learning_rate
# Calls lr_sche at the start of each epoch and applies the returned rate.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_sche)

# Both datasets repeat indefinitely, so the epoch length and the number of
# validation batches must be stated explicitly: floor(60000/128) training
# batches and floor(10000/128) validation batches per epoch.
model.fit(
    x=train_dataset,
    validation_data=test_dataset,
    epochs=100,
    steps_per_epoch=60000 // 128,
    validation_steps=10000 // 128,
    callbacks=[tensorboard_callback, lr_callback],
)

# Persist the trained model; the .h5 extension selects the HDF5 format.
model.save(r'model_data/tensor_board.h5')
二、上述代码构造了两个回调函数（TensorBoard 和 LearningRateScheduler），它们会在训练过程中被自动调用；训练结束后，输入以下命令即可查看相关数据的变化。
tensorboard --logdir <日志保存目录>（本例中为 logs）