fit函数 model_深度学习与Tensorflow学习笔记2 ——回调函数callbacks和Tensorboard

    上一期我们从Fashion-mnist数据集开始,使用Tensorflow.keras搭建一个简单的神经网络来处理分类问题。通过这个简单例子我们熟悉了tf.keras的调用。本期我们来学习keras下面的回调函数callbacks的用法。这里,简单的再说一句,Tensorflow有非常完善的官方文档,相当于学习手册。(而且还有中文网站:https://tensorflow.google.cn/)在学习过程中,很多问题都可以通过查看官方文档来解决。另外官方文档上还配有很多学习案例,截图如下:

47c09477ac5500fa3173d4e14869bf76.png

    有了这么完整的学习手册,大家完全可以对照着进行自学。本期我们就带着大家了解一下Tensorflow回调函数callbacks下面一个非常好用的可视化工具Tensorboard的使用。我们可以直接在官网上找到Tensorboard的相关介绍,位于keras.callbacks下面。通过下面这张图,我们可以看到callbacks(复数)下面不止有TensorBoard,还有很多工具。常用的还有有EarlyStopping,ModelCheckpoint等

161bf8b9874d0a11283eda7257486243.png

    简单来说TensorBoard是TensorFlow下的一个监控网络训练的可视化工具,他可以展示Tensorflow 在运行过程中的计算图、各种指标随着时间的变化趋势以及训练中使用到的图像等信息。这样能够帮助研究者们可视化训练大规模神经网络过程中出现的复杂且不好理解的运算,展示训练过程中绘制的图像、网络结构等。下面我们就通过实战来演练一下。

  1.  和之前面的代码和之前一样,把相关库和数据集导入之后,创建网络。

import numpy as np
import pandas as pd
import os
import sys
import time
import sklearn
import tensorflow as tf
from tensorflow import keras

import matplotlib as mpl
import matplotlib.pyplot as plt

# 下面导入Fashion-mnist数据集
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

# 打印数据集形状,看看是否导入成功
print(x_train_all.shape, y_train_all.shape)
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
>>>
(60000, 28, 28) (60000,)
(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)

    2. 在搭建一个4层的神经网络模型:

model = tf.keras.Sequential()
model.add(keras.layers.Flatten(input_shape = [28,28]))
model.add(keras.layers.Dense(400, activation = 'relu'))
model.add(keras.layers.Dense(100, activation = 'relu'))

# 输出层:
model.add(keras.layers.Dense(10, activation = 'softmax'))
model.summary()
>>>
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 400)               314000    
_________________________________________________________________
dense_1 (Dense)              (None, 100)               40100     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1010      
=================================================================
Total params: 355,110
Trainable params: 355,110
Non-trainable params: 0
_________________________________________________________________

3. 创建callbacks列表

    callbacks是在训练过程中对网络的迭代情况进行监控,因而需要将callback的内容放在网络的训练过程中,也就是将定义好的callbacks列表作为一个参数放在model.fit中。我们创建一个包含TensorBoard、ModelCheckpoint和Earlystopping的callbacks。对于Tensorboard需要指定一个文件夹,ModelCheckpoint需要一个文件名。

# callbacks: Tensorboard, earlystopping, ModelCheckpoint
logdir = os.path.join("callbacks" ) # Tensorboard需要一个文件夹
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir,
                                'fashion_mnist_model.h5')


callbacks = [
    keras.callbacks.TensorBoard(logdir),
    keras.callbacks.ModelCheckpoint(output_model_file,
                                   save_best_only= True),
    keras.callbacks.EarlyStopping(patience=5, min_delta= 1e-3),
]
model.compile(loss = 'sparse_categorical_crossentropy',
             optimizer = 'adam',
             metrics = ['accuracy'])

  然后在训练时,将callbacks列表作为一个参数传给model.fit:

history = model.fit(x_train, y_train, epochs = 10, 
                    validation_data = (x_valid, y_valid),
                   callbacks = callbacks)
>>>
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 12s 214us/sample - loss: 2.3347 - accuracy: 0.7135 - val_loss: 0.7192 - val_accuracy: 0.7390
Epoch 2/10
55000/55000 [==============================] - 10s 181us/sample - loss: 0.6273 - accuracy: 0.7719 - val_loss: 0.5271 - val_accuracy: 0.8044
Epoch 3/10
55000/55000 [==============================] - 10s 181us/sample - loss: 0.5346 - accuracy: 0.8094 - val_loss: 0.5090 - val_accuracy: 0.8220
Epoch 4/10
55000/55000 [==============================] - 10s 185us/sample - loss: 0.4754 - accuracy: 0.8322 - val_loss: 0.4284 - val_accuracy: 0.8514
Epoch 5/10
55000/55000 [==============================] - 10s 180us/sample - loss: 0.4346 - accuracy: 0.8445 - val_loss: 0.4937 - val_accuracy: 0.8384
Epoch 6/10
55000/55000 [==============================] - 10s 179us/sample - loss: 0.4206 - accuracy: 0.8519 - val_loss: 0.4097 - val_accuracy: 0.8642
Epoch 7/10
55000/55000 [==============================] - 10s 183us/sample - loss: 0.4014 - accuracy: 0.8605 - val_loss: 0.4938 - val_accuracy: 0.8440
Epoch 8/10
55000/55000 [==============================] - 10s 184us/sample - loss: 0.3839 - accuracy: 0.8667 - val_loss: 0.4187 - val_accuracy: 0.8644
Epoch 9/10
55000/55000 [==============================] - 10s 183us/sample - loss: 0.3687 - accuracy: 0.8710 - val_loss: 0.4370 - val_accuracy: 0.8598
Epoch 10/10
55000/55000 [==============================] - 10s 182us/sample - loss: 0.3647 - accuracy: 0.8737 - val_loss: 0.3953 - val_accuracy: 0.8702

4. 打开tensorboard,并在浏览器中查看

训练结束之后,我们在之前定义logdir文件夹位置是创建的callbacks的文件夹下发现多了几个文件:5b041db238e32bcee524a9b52c16b43f.png

「下面重点来了, 怎么打开查看?不是说Tensorboard是可视化的工具?怎么个可视化?」这也是本期的重点部分,怎么打开Tensor board? 

    有两种方法,一种是打开windows的命令行工具,cd到callbacks的「上一级目录」, 之后再输入tensorboard --logdir=callbacks运行tensorboard。或者使用Anaconda自带的prompt工具打开。这里我们以prompt为例:如下图

498b1a527f1fa6cab0aa27822945233b.png

    中间可能因为环境或者版本问题会有警告之类的,但是只要看到最后给了一个地址:localhost:6066就成功了。这里相当于tensorboard在本地构建了一个服务器,端口是6066。我们直接在浏览器输入localhost:6066即可打开。如下图

bf2d082b4f0a54bd6abfb6466468f61c.png    是不是很激动,有木有!!在GRAPHS里面还有网络的结构,这个在复杂的网络的结构里面非常好用cad1c906d89f6b7f09ffc93ec9d3cbc9.png

  这里要强调两个坑:大家在使用的时候注意避免。

  1. 命令行输入tensorboard --logdir=callbacks时,要注意:
  • callbacks不需要加上引号
  • logdir之后的“=”号两端千万不能有空格。这个和python输入代码时还是有些不一样
  1. 需要在chrome/firefox浏览器中打开,在其他浏览可能会出现打不开的情况!

  好了,本期就介绍到这里,赶紧试试tensorboard这个工具吧

274963c6e39cf206cf22e91fa94a4ddd.gif

End

老狼| 我是老狼,机械行业十年老兵。我正在学习Python,期望使用Python提高工作效率,给生活带来更多可能。欢迎扫码关注和我交流,我会在公众号内给您分享一些使用Python来提高效率的真实案例

bfb46afbaaa216259460b6701bdbd99d.png

扫一扫关注公众号获取视频学习资料,不断更新中!

先有收获,再点在看!
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值