mnist手写数字识别--python3.8 tensorflow cpu 2.3.0

mnist手写数字识别–python3.8 tensorflow cpu 2.3.0

原文链接:https://geektutu.com/post/tensorflow-mnist-simplest.html

本文介绍了机器学习中的hello word ----------mnist 😗

原文采用:

python 3.6
tensorflow 1.4

本文采用:(截至2020年10月最高支持tensorflow版本,部分代码略作修改)

python 3.8.5(64位)
tensorflow cpu 2.3.0

先上结果:
loss:
在这里插入图片描述accuracy:
在这里插入图片描述

神经网络:
在这里插入图片描述
在这里插入图片描述

bias and weight distribution:
在这里插入图片描述

Histograms:
在这里插入图片描述
在这里插入图片描述

代码:

model.py

import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()

class Network:
    def __init__(self):
        self.learning_rate = 0.001
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        self.x = tf.compat.v1.placeholder(tf.float32, [None, 784], name="x")
        self.label = tf.compat.v1.placeholder(tf.float32, [None, 10], name="label")

        self.w = tf.Variable(tf.zeros([784, 10]), name="fc/weight")
        self.b = tf.Variable(tf.zeros([10]), name="fc/bias")
        self.y = tf.nn.softmax(tf.matmul(self.x, self.w) + self.b, name="y")

        self.loss = -tf.reduce_sum(self.label * tf.compat.v1.log(self.y + 1e-10))
        self.train = tf.compat.v1.train.GradientDescentOptimizer(self.learning_rate).minimize(
            self.loss, global_step=self.global_step)

        predict = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(predict, "float"))

        # 创建 summary node
        # w, b 画直方图
        # loss, accuracy画标量图
        tf.compat.v1.summary.histogram('weight', self.w)
        tf.compat.v1.summary.histogram('bias', self.b)
        tf.compat.v1.summary.scalar('loss', self.loss)
        tf.compat.v1.summary.scalar('accuracy', self.accuracy)

train.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()


CKPT_DIR = 'ckpt'


class Train:
    def __init__(self):
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.data = input_data.read_data_sets('./data_set', one_hot=True)

    def train(self):
        batch_size = 64
        train_step = 20000
        step = 0
        save_interval = 1000
        saver = tf.compat.v1.train.Saver(max_to_keep=5)

        # merge所有的summary node
        merged_summary_op = tf.compat.v1.summary.merge_all()
        # 可视化存储目录为当前文件夹下的 log
        merged_writer = tf.compat.v1.summary.FileWriter("./log", self.sess.graph)

        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
            # 读取网络中的global_step的值,即当前已经训练的次数
            step = self.sess.run(self.net.global_step)
            print('Continue from')
            print('        -> Minibatch update : ', step)

        while step < train_step:
            x, label = self.data.train.next_batch(batch_size)
            _, loss, merged_summary = self.sess.run(
                [self.net.train, self.net.loss, merged_summary_op],
                feed_dict={self.net.x: x, self.net.label: label}
            )
            step = self.sess.run(self.net.global_step)

            if step % 100 == 0:
                merged_writer.add_summary(merged_summary, step)

            if step % save_interval == 0:
                saver.save(self.sess, CKPT_DIR + '/model', global_step=step)
                print('%s/model-%d saved' % (CKPT_DIR, step))

    def calculate_accuracy(self):
        test_x = self.data.test.images
        test_label = self.data.test.labels
        accuracy = self.sess.run(self.net.accuracy,
                                 feed_dict={self.net.x: test_x, self.net.label: test_label})
        print("准确率: %.2f,共测试了%d张图片 " % (accuracy, len(test_label)))


if __name__ == "__main__":
    app = Train()
    app.train()
    app.calculate_accuracy()
# tensorboard --logdir=./log

predict.py

在这里插import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
from model import Network


# python 3.6
# tensorflow 1.4
# pillow(PIL) 4.3.0
# 使用tensorflow的模型来预测手写数字
# 输入是28 * 28像素的图片,输出是个具体的数字


CKPT_DIR = 'ckpt'


class Predict:
    def __init__(self):
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())

        # 加载模型到sess中
        self.restore()

    def restore(self):
        saver = tf.compat.v1.train.Saver()
        ckpt = tf.compat.v1.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError("未保存任何模型")

    def predict(self, image_path):
        # 读图片并转为黑白的
        img = Image.open(image_path).convert('L')
        flatten_img = np.reshape(img, 784)
        x = np.array([1 - flatten_img])
        y = self.sess.run(self.net.y, feed_dict={self.net.x: x})

        # 因为x只传入了一张图片,取y[0]即可
        # np.argmax()取得独热编码最大值的下标,即代表的数字
        print(image_path)
        print('        -> Predict digit', np.argmax(y[0]))


if __name__ == "__main__":
    app = Predict()
    app.predict('./test_images/0.png')
    app.predict('./test_images/1.png')
    app.predict('./test_images/4.png')
入代码片

测试图片:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
文件设置:
在这里插入图片描述
data_set 为训练的MNIST库,运行代码自动生成
ckpt为保存模型,运行代码自动生成
log为可视化路径,运行代码自动生成
test_images为识别图片的地址

在project路径运行tensorboard:

tensorboard --logdir=./log

浏览器访问localhost得到可视化结果,端口6006(具体见cmd运行结果):http://localhost:6006/

在这里插入图片描述
pycharm中显示结果:

train.py:

在这里插入图片描述

predict.py:
在这里插入图片描述

  • 2
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值