mnist手写数字识别–python3.8 tensorflow cpu 2.3.0
原文链接:https://geektutu.com/post/tensorflow-mnist-simplest.html
本文介绍了机器学习中的 hello world —— MNIST 😗
原文采用:
python 3.6 tensorflow 1.4
本文采用:(截至2020年10月最高支持tensorflow版本,部分代码略作修改)
python 3.8.5(64位) tensorflow cpu 2.3.0
先上结果:
loss:
accuracy:
神经网络:
bias and weight distribution:
Histograms:
代码:
model.py
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
class Network:
    """Single-layer softmax classifier, built as a TF1-style graph.

    Instantiating the class adds the input placeholders, the layer's
    variables, the loss, the training op, the accuracy op and the TensorBoard
    summary ops to the default graph.
    """

    def __init__(self):
        v1 = tf.compat.v1

        self.learning_rate = 0.001
        # Non-trainable counter handed to the optimizer as its global step.
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Inputs: rows of 784 pixel values and rows of 10 label values.
        self.x = v1.placeholder(tf.float32, [None, 784], name="x")
        self.label = v1.placeholder(tf.float32, [None, 10], name="label")

        # Parameters of the fully connected layer, all initialised to zero.
        self.w = tf.Variable(tf.zeros([784, 10]), name="fc/weight")
        self.b = tf.Variable(tf.zeros([10]), name="fc/bias")

        logits = tf.matmul(self.x, self.w) + self.b
        self.y = tf.nn.softmax(logits, name="y")

        # Cross-entropy summed over the batch; the 1e-10 offset keeps the
        # argument of log() strictly positive.
        log_prob = v1.log(self.y + 1e-10)
        self.loss = -tf.reduce_sum(self.label * log_prob)

        optimizer = v1.train.GradientDescentOptimizer(self.learning_rate)
        self.train = optimizer.minimize(
            self.loss, global_step=self.global_step)

        # Mean of the per-row "arg-max of label == arg-max of output" flags.
        is_correct = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(is_correct, "float"))

        # Summary nodes: histograms for w and b, scalar plots for loss and
        # accuracy.
        for tag, tensor in (("weight", self.w), ("bias", self.b)):
            v1.summary.histogram(tag, tensor)
        for tag, tensor in (("loss", self.loss), ("accuracy", self.accuracy)):
            v1.summary.scalar(tag, tensor)
train.py
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
# Directory that training checkpoints are saved to and restored from.
CKPT_DIR = 'ckpt'
class Train:
    """Trains a ``Network`` on MNIST and reports its test-set accuracy."""

    def __init__(self):
        """Build the graph, start a session and load the MNIST data set."""
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())
        # NOTE(review): ``input_data`` is imported from
        # ``tensorflow.examples.tutorials.mnist``; that package is reportedly
        # absent from TensorFlow 2.x pip wheels -- confirm it is importable in
        # the target environment.
        self.data = input_data.read_data_sets('./data_set', one_hot=True)

    def train(self):
        """Run mini-batch training, resuming from the latest checkpoint.

        Summaries are written to ``./log`` every 100 steps and a checkpoint is
        saved under ``CKPT_DIR`` every ``save_interval`` steps. Training stops
        once the global step reaches ``train_step``.
        """
        batch_size = 64
        train_step = 20000
        step = 0
        save_interval = 1000

        saver = tf.compat.v1.train.Saver(max_to_keep=5)
        # Merge every summary node registered on the graph into one op.
        merged_summary_op = tf.compat.v1.summary.merge_all()
        # TensorBoard event files go to ./log in the working directory.
        merged_writer = tf.compat.v1.summary.FileWriter(
            "./log", self.sess.graph)

        try:
            ckpt = tf.compat.v1.train.get_checkpoint_state(CKPT_DIR)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(self.sess, ckpt.model_checkpoint_path)
                # The restored global_step is the number of updates already
                # performed, so training continues from there.
                step = self.sess.run(self.net.global_step)
                print('Continue from')
                print('        -> Minibatch update : ', step)

            while step < train_step:
                x, label = self.data.train.next_batch(batch_size)
                _, loss, merged_summary = self.sess.run(
                    [self.net.train, self.net.loss, merged_summary_op],
                    feed_dict={self.net.x: x, self.net.label: label}
                )
                # Read the counter in a separate run so the value is the one
                # after this update.
                step = self.sess.run(self.net.global_step)

                if step % 100 == 0:
                    merged_writer.add_summary(merged_summary, step)

                if step % save_interval == 0:
                    saver.save(self.sess, CKPT_DIR + '/model',
                               global_step=step)
                    print('%s/model-%d saved' % (CKPT_DIR, step))
        finally:
            # Closing flushes buffered events; without it the last summaries
            # can be lost when the process exits or training raises.
            merged_writer.close()

    def calculate_accuracy(self):
        """Evaluate the network on the whole test split and print the result."""
        test_x = self.data.test.images
        test_label = self.data.test.labels
        accuracy = self.sess.run(
            self.net.accuracy,
            feed_dict={self.net.x: test_x, self.net.label: test_label})
        print("准确率: %.2f,共测试了%d张图片 " % (accuracy, len(test_label)))
def _run_training():
    """Train the model, then print its accuracy on the test images."""
    trainer = Train()
    trainer.train()
    trainer.calculate_accuracy()


if __name__ == "__main__":
    # To view the logged summaries afterwards, run:
    #   tensorboard --logdir=./log
    _run_training()
predict.py
import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
from model import Network
# 原文环境:python 3.6 / tensorflow 1.4 / pillow(PIL) 4.3.0
# 本文环境:python 3.8.5 / tensorflow cpu 2.3.0
# 使用tensorflow的模型来预测手写数字
# 输入是28 * 28像素的图片,输出是个具体的数字
# Directory searched for the saved checkpoint that is restored before predicting.
CKPT_DIR = 'ckpt'
class Predict:
    """Loads the trained ``Network`` from ``CKPT_DIR`` and classifies images."""

    def __init__(self):
        """Build the graph, start a session and restore the saved weights."""
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())

        # Load the trained model into the session.
        self.restore()

    def restore(self):
        """Restore the latest checkpoint found in ``CKPT_DIR``.

        Raises:
            FileNotFoundError: if no checkpoint has been saved yet.
        """
        saver = tf.compat.v1.train.Saver()
        ckpt = tf.compat.v1.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError("未保存任何模型")

    def predict(self, image_path):
        """Classify one image file and print the predicted digit.

        Args:
            image_path: path of an image holding exactly 784 pixels
                (28 x 28); any other size makes the reshape raise
                ``ValueError``.

        Returns:
            The predicted digit as an ``int``.
        """
        # Read the image as 8-bit greyscale; the ``with`` block releases the
        # file handle once the pixels have been copied out.
        with Image.open(image_path) as image:
            pixels = np.asarray(image.convert('L')).astype(np.float32)

        # Scale to [0, 1] *before* inverting. Computing ``1 - pixels`` on the
        # raw uint8 data wraps modulo 256 (255 -> 2, 0 -> 1) instead of
        # inverting the image.
        # NOTE(review): the inversion assumes a dark digit on a light
        # background and training data scaled to [0, 1] -- confirm against
        # the data loader.
        x = 1.0 - pixels.reshape(1, 784) / 255.0

        y = self.sess.run(self.net.y, feed_dict={self.net.x: x})

        # Only one image was fed, so y[0] is its output row; the index of the
        # largest entry is the predicted digit.
        digit = int(np.argmax(y[0]))
        print(image_path)
        print('        -> Predict digit', digit)
        return digit
def _run_predictions():
    """Restore the model once and classify the bundled sample images."""
    predictor = Predict()
    sample_paths = (
        './test_images/0.png',
        './test_images/1.png',
        './test_images/4.png',
    )
    for sample_path in sample_paths:
        predictor.predict(sample_path)


if __name__ == "__main__":
    _run_predictions()
测试图片:
文件设置:
data_set 为训练的MNIST库,运行代码自动生成
ckpt为保存模型,运行代码自动生成
log为可视化路径,运行代码自动生成
test_images为识别图片的地址
在project路径运行tensorboard:
tensorboard --logdir=./log
浏览器访问localhost得到可视化结果,端口6006(具体见cmd运行结果):http://localhost:6006/
pycharm中显示结果:
train.py
:
predict.py
: