深度学习-TF2.0 MNIST手写字体识别

系列文章
1.机器学习-Window安装Anaconda3
2.机器学习-Windows使用Anaconda3创建tensorflow2.0CPU环境
3.机器学习-数据集-经典MNIST数据集

环境

windows10
tensorflow==2.0.0
python=3.7

数据集

MNIST
keras.datasets.mnist.load_data()

其他依赖

一:下载并安装 graphviz
http://www.graphviz.org/download/#windows

二:
pip install graphviz
pip install pydot

训练代码

来源:《深度学习实战基于TensorFlow 2.0的人工智能开发应用》
mnist_cnn.py
import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np 
import matplotlib.pyplot as plt
from datetime import datetime
from matplotlib.font_manager import FontProperties
font = FontProperties(fname="/Library/Fonts/Songti.ttc",size=8)

def gen_datas():
    """生成数据
    参数:
        无
    返回:
        inputs: 训练图像
        outputs: 训练标签
        eval_images: 测试图像
    """
    # 读取MNIST数据集
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
    # 获取前1000个图像数据
    train_labels = train_labels[:1000]
    # 获取前1000个评估使用图像
    eval_images = train_images[:1000]
    # 调整图像数据维度,供训练使用
    train_images = train_images[:1000].reshape(-1,28,28,1)/255.0
    return train_images, train_labels, eval_images

def compile_model(model):
    """神经网络参数配置
    参数:
        model: 神经网络实例
    返回:
        无
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"]
    )

def create_model():
    """使用keras新建神经网络
    参数:
        无
    返回:
        model: 神经网络实例
    """
    model = tf.keras.Sequential(name="MNIST-CNN")
    # 卷积层-1
    model.add(
        layers.Conv2D(32, (3,3),
        padding="same",
        activation=tf.nn.relu,
        input_shape=(28,28,1),
        name="conv-1")
        )
    # 最大池化层-1
    model.add(
        layers.MaxPooling2D(
            (2,2),
            name="max-pooling-1"
        )
    )
    # 卷积层-2
    model.add(
        layers.Conv2D(64, (3,3),
        padding="same",
        activation=tf.nn.relu,
        name="conv-2")
        )
    # 最大池化层-2
    model.add(
        layers.MaxPooling2D(
            (2,2),
            name="max-pooling-2"
        )
    )
    # 全连接层-1
    model.add(layers.Flatten(name="fullc-1"))
    # 全连接层-2
    model.add(
        layers.Dense(512,
        activation=tf.nn.relu,
        name="fullc-2")
    )
    # 全连接层-3
    model.add(
        layers.Dense(10,
        activation=tf.nn.softmax,
        name="fullc-3")
    )
    # 配置损失计算及优化器
    compile_model(model)
    return model

def display_nn_structure(model, nn_structure_path):
    """展示神经网络结构
    参数:
        model: 神经网络对象
        nn_structure_path: 神经网络结构保存路径
    返回:
        无
    """
    model.summary()
    keras.utils.plot_model(model, nn_structure_path, show_shapes=True)

def callback_only_params(model_path):
    """保存模型回调函数
    参数:
        model_path: 模型文件路径
    返回:
        ckpt_callback: 回调函数
    """
    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        verbose=1,
        save_weights_only=True,
        save_freq='epoch'
    )
    return ckpt_callback

def tb_callback(model_path):
    """保存Tensorboard日志回调函数
    参数:
        model_path: 模型文件路径
    返回:
        tensorboard_callback: 回调函数
    """
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=model_path,
        histogram_freq=1,
        write_images=True)
    return tensorboard_callback


def train_model(model, inputs, outputs, model_path, log_path):
    """训练神经网络
    参数:
        model: 神经网络实例
        inputs: 输入数据
        outputs: 输出数据
        model_path: 模型文件路径
        log_path: 日志文件路径
    返回:
        无
    """
    # 回调函数
    ckpt_callback = callback_only_params(model_path)
    # tensorboard回调
    tensorboard_callback = tb_callback(log_path)
    # 保存参数
    model.save_weights(model_path.format(epoch=0))
    summary_writer = tf.summary.create_file_writer(log_path)
    with summary_writer.as_default():
        tf.summary.image("MNIST handwriters", inputs, max_outputs=10, step=0)
    # 训练模型,并使用最新模型参数
    history = model.fit(
            inputs,
            outputs,
            epochs=20,
            callbacks=[ckpt_callback, tensorboard_callback],
            verbose=0
            )
    # 绘制图像
    # plot_history(history)

def train_model_global(model, inputs, outputs, model_path):
    """训练神经网络
    参数:
        model: 神经网络实例
        inputs: 输入数据
        outputs: 输出数据
        model_path: 模型文件路径
        log_path: 日志文件路径
    返回:
        无
    """
    # 训练模型,并使用最新模型参数
    history = model.fit(
            inputs,
            outputs,
            epochs=20,
            verbose=1
            )
    # 保存参数
    model.save(model_path)
    

def load_model(model, model_path):
    """载入模型
    参数:
        model: 神经网络实例
        model_path: 模型文件路径
    返回:
        无
    """
    # 检查最新模型
    latest = tf.train.latest_checkpoint(model_path)
    print("latest:{}".format(latest))
    # 载入模型
    model.load_weights(latest)

def prediction(model, model_path, inputs):
    """神经网络预测
    参数:
        model: 神经网络实例
        model_path: 模型文件路径
        inputs: 输入数据
    返回:
        pres: 预测值
    """
    # 载入模型
    load_model(model, model_path)
    # 预测值
    pres = model.predict(inputs)
    # print("prediction:{}".format(pres))
    # 返回预测值
    return pres

def confusion_matrix(model, model_path, inputs, evals):
    """混淆矩阵可视化
    参数:
        model: 神经网络实例
        inputs: 输入数据
        evals: 测试数据
        model_path: 模型文件路径
    返回:
        evals: 评估数据标签值
        pres: 预测值
        confusion_mat: 混淆矩阵
    """
    # 预测值
    pres = prediction(model, model_path, inputs)
    pres = tf.math.argmax(pres, 1)
    confusion_mat = tf.math.confusion_matrix(evals, pres)
    # 获取矩阵维度
    num = tf.shape(confusion_mat)[0]
    # 迭代添加文本
    for row in range(num):
        for col in range(num):
            plt.text(row, col, confusion_mat[row][col].numpy())
    # 图像写入绘图区
    plt.imshow(confusion_mat, cmap=plt.cm.Blues)
    # 添加标题
    plt.title("手写字体识别混淆矩阵",fontproperties=font)
    # 保存图像
    plt.savefig("./images/confusion_matrix.png", format="png", dpi=300)
    # 展示图像
    plt.show()
    return evals, pres, confusion_mat



def plot_prediction(model, model_path, inputs, evals):
    """可视化预测结果
    参数:
        model: 神经网络实例
        inputs: 输入数据
        evals: 测试数据
        model_path: 模型文件路径
    返回:
        无
    """
    # 预测值
    pres = prediction(model, model_path, inputs)
    pres = tf.math.argmax(pres, 1)
    for i in range(16):
        plt.subplot(4,4,i+1)
        plt.subplots_adjust(wspace=0.5, hspace=0.8)
        plt.imshow(evals[i], cmap=plt.cm.binary)
        plt.title("预测值:{}".format(pres[i]), fontproperties=font)
    plt.savefig("./images/cnn-pre.png", format="png", dpi=300)
    plt.show()

if __name__ == "__main__":
    stamp = datetime.now().strftime("%Y%m%d_%H_%M_%S")  # 改动地方
    model_path = "./models/cnn/mnist-cnn"+stamp
    model_path_global = "./models/cnn-global/mnist-cnn"+stamp+".h5"
    # log_path = "./logs/cnn/mnist-cnn"+stamp
    log_path = ".\\logs\\cnn\\mnist-cnn" + stamp # 改动地方

    inputs, outputs, evals = gen_datas()
    print("inputs shape:",inputs.shape)
    print("outputs shape:", outputs.shape)
    # 载入完整模型
    # test_images = tf.convert_to_tensor([inputs[0]])
    # model = tf.keras.models.load_model("./models/cnn-global/mnist-cnn20200321-18:23:12.h5")
    # pre = model.predict(test_images)
    # pre = tf.math.argmax(pre, 1)
    # print("prediction:{}".format(pre))
    # 只载入模型参数
    # 新建网络结构
    model = create_model()
    display_nn_structure(model, "./images/cnn-structure.png")
    # 训练模型:只保存权重
    train_model(model, inputs, outputs, model_path, log_path)
    # 训练模型:保存完整模型
    # train_model_global(model, inputs, outputs, model_path_global)
    model_path = "./models/cnn/"
    # test_images = tf.convert_to_tensor([inputs[0]])
    load_model(model, model_path)
    # pre = model.predict(test_images)
    # pre = tf.math.argmax(pre, 1)
    # print("prediction:{}".format(pre))
    # test_images = tf.convert_to_tensor([inputs[0]])
    # print(test_images)
    # pres = prediction(model, model_path, inputs[:10])
    # print("prediciton:{}".format(tf.math.argmax(pres,1)))
    # plot_prediction(model, model_path, inputs[:16], evals[:16])
    # print(model.weights)
    # 混淆矩阵
    # evals, pres, confusion_mat = confusion_matrix(model, model_path, inputs[:16], outputs[:16])
    # print("手写字体标签值:{}".format(evals))
    # print("手写字体预测值:{}".format(pres))
    # print("混淆矩阵:{}".format(confusion_mat))
    # 获取模型权重参数
    for weight in model.weights:
        print("name:", weight.name)
        print("model weights:", weight.numpy())

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值