Tensorboard 对 CNN可视化 (以mnist识别书写数字为例)

本文是卷积神经网络的一篇小入门
原文链接:
https://geektutu.com/post/tensorflow2-mnist-cnn.html

train.py

import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

'''
python 3.8.5
tensorflow 2.3.1
'''


class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
        model.add(layers.Conv2D(
            32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第2层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第3层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))

        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))

        model.summary()

        self.model = model




class DataSource(object):
    def __init__(self):
        # mnist数据集存储的位置,如何不存在将自动下载
        data_path = os.path.abspath(os.path.dirname(
            __file__)) + '/data_set_tf2/mnist.npz'
        (train_images, train_labels), (test_images,
                                       test_labels) = datasets.mnist.load_data(path=data_path)
        # 6万张训练图片,1万张测试图片
        train_images = train_images.reshape((60000, 28, 28, 1))
        test_images = test_images.reshape((10000, 28, 28, 1))
        # 像素值映射到 0 - 1 之间
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels

class Train:
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()
    def train(self):
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period 每隔5epoch保存一次
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(
            check_path, save_weights_only=True, verbose=1, period=5)
        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        tensorboard_callback=tf.keras.callbacks.TensorBoard(log_dir="./log",histogram_freq=1)
        self.cnn.model.fit(self.data.train_images, self.data.train_labels,
                           epochs=5, validation_data=(self.data.train_images,self.data.train_labels),
                           callbacks=[tensorboard_callback,save_model_cb])
        test_loss, test_acc = self.cnn.model.evaluate(
            self.data.test_images, self.data.test_labels)
        print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
        tf.constant()


if __name__ == "__main__":
    app = Train()
    app.train()

# tensorboard --logdir=./log

predict.py

import tensorflow as tf
from PIL import Image
import numpy as np

from train import CNN

'''
python 3.8
tensorflow 2.3.1
pillow 是个版本就可以
'''


class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        self.cnn = CNN()
        # 恢复网络权重
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        # 以黑白方式读取图片
        img = Image.open(image_path).convert('L')
        img = np.reshape(img, (28, 28, 1)) / 255.
        x = np.array([1 - img])

        # API refer: https://keras.io/models/model/
        y = self.cnn.model.predict(x)

        # 因为x只传入了一张图片,取y[0]即可
        # np.argmax()取得最大值的下标,即代表的数字
        print(image_path)
        print(y[0])
        print('        -> Predict digit', np.argmax(y[0]))


if __name__ == "__main__":
    app = Predict()
    app.predict('./test_images/0.png')
    app.predict('./test_images/1.png')
    app.predict('./test_images/4.png')

  • 运行train.py:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________
1875/1875 [==============================] - 21s 11ms/step - loss: 0.1553 - accuracy: 0.9509 - val_loss: 0.0456 - val_accuracy: 0.9863
Epoch 2/5
1875/1875 [==============================] - 21s 11ms/step - loss: 0.0469 - accuracy: 0.9852 - val_loss: 0.0374 - val_accuracy: 0.9885
Epoch 3/5
1875/1875 [==============================] - 21s 11ms/step - loss: 0.0341 - accuracy: 0.9898 - val_loss: 0.0220 - val_accuracy: 0.9934
Epoch 4/5
1875/1875 [==============================] - 21s 11ms/step - loss: 0.0259 - accuracy: 0.9915 - val_loss: 0.0194 - val_accuracy: 0.9936
Epoch 5/5
1874/1875 [============================>.] - ETA: 0s - loss: 0.0212 - accuracy: 0.9930
Epoch 00005: saving model to ./ckpt\cp-0005.ckpt
1875/1875 [==============================] - 21s 11ms/step - loss: 0.0212 - accuracy: 0.9930 - val_loss: 0.0140 - val_accuracy: 0.9960
313/313 [==============================] - 1s 3ms/step - loss: 0.0305 - accuracy: 0.9906
准确率: 0.9906,共测试了10000张图片
  • 运行predict.py
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________
./test_images/0.png
[9.9998450e-01 8.4827110e-07 2.9172602e-06 5.5995207e-07 9.1768314e-08
 5.0970073e-07 3.3251815e-06 3.7221660e-08 2.1960172e-06 5.0798021e-06]
        -> Predict digit 0
./test_images/1.png
[1.3399864e-06 9.9998462e-01 1.6632547e-06 9.0319432e-11 7.1405714e-07
 3.7778250e-07 5.2523381e-09 8.7272074e-06 1.4942656e-06 1.0218997e-06]
        -> Predict digit 1
./test_images/4.png
[1.7766638e-07 4.0547842e-05 2.1346721e-05 2.5850220e-08 9.9976021e-01
 7.2087487e-06 4.1368772e-08 8.3897612e-05 1.1050152e-05 7.5568620e-05]
        -> Predict digit 4
  • 在Terminal 中运行如下代码
    tensorboard --logdir=./log
    我们有:
    在这里插入图片描述
  • 打开localhost:6006
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
    +在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 在这里插入图片描述
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值