这里的回调函数介绍得并不详细,只记录了笔者学习过程中用到的,后续随着学习,会逐渐补充。
1. 回调函数的理解
笔者目前的理解,回调函数是在程序运行中,满足某些要求,就会触发的函数。
2. tf.keras.callbacks
2.1 TensorBoard
官网地址点击此处
源码点击此处
Tensorboard是TensorFlow内置的可视化工具,记录TensorBoard中的事件,包括Metrics summary plots(指标摘要图,指标即损失loss或准确度accuracy),Training graph visualization(训练图可视化),Activation histograms(激活直方图), Sampled profiling(采样分析)。
具体的,可实现如下功能:
- 对如损失和准确度等指标进行跟踪并实现可视化。
- 对模型图进行可视化,比如操作和层(ops and layers)
- 查看权重(weight)、偏差(biases)或其他张量(tensor)随时间变化的直方图
- Projecting embeddings to a lower dimensional space(将嵌入物投影到较低维度的空间)。这句还不能理解是什么意思
- 展示图像、文本和音频数据
- 分析一个TensorFlow程序
- 其他~~~
2.1.1 启动
可在命令行中输入如下命令来启动TensorBoard。
tensorboard --logdir=path_to_your_logs
2.1.2 参数
__init__(
log_dir='logs',
histogram_freq=0,
write_graph=True,
write_images=False,
update_freq='epoch',
profile_batch=2,
embeddings_freq=0,
embeddings_metadata=None,
**kwargs
)
- log_dir: 保存TensorBoard解析文件的路径。
- histogram_freq: 计算模型各层的激活度和权重直方图的频率(每个周期中)。如果设置为0,将不计算直方图。必须为直方图可视化指定验证数据(或拆分)。
- write_gragh: 是否在TensorBoard中可视化图形。当write_graph设置为True时,文件可能会变得很大。
- write_images: 是否在TensorBoard中编写模型权重来实现可视化的图片。
- update_freq: 输入“batch”或“epoch”或整数,在每一个batch或epoch或整数个数据(samples)结束后将损失(loss)和指标(metrics)添加到TensorBoard中。
- profile_batch: 对要采样的批次进行分析,计算特征。默认为2,profile_batch为0时禁用。在eager mode中必须使用。
- embeddings_freq: 嵌入层可视化的频率,被设置为0,则不可见
- embeddings_metadata: 字典,它将图层名称映射到文件名,该文件名的文件保存着嵌入层的元数据。
2.2 示例
此处示例来自此处
以下代码为节选
# 回调函数需在拟合之前设置
# 回调函数,使用Tensorboard,earlystopping,ModelCheckpoint
# Tensorboard需要使用一个文件夹
# ModelCheckpoint需要一个文件名
logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,
"fashion_mnist_model.h5")
callbacks = [
keras.callbacks.TensorBoard(logdir),
keras.callbacks.ModelCheckpoint(output_model_file,
save_best_only=True),
keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)
]
history = model.fit(x_train, y_train, epochs=10,
validation_data=(x_valid, y_valid),
callbacks = callbacks)
# epochs设置数据遍历的次数,validation_data用来设置训练中检验模型的测试集
运行后,目录下出现tensorboard文件夹,在命令行中输入tensorboard --logdir=callbacks,其中,callbacks是我们创建的tensorboard文件夹名称。
tensorboard --logdir=callbacks
笔者的环境下输出结果如下
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6006/ (Press CTRL+C to quit)
打开浏览器,输入localhost:6006,进入.
界面如下: