代码实例,基于TensorFlow+Flask部署MNIST手写数字识别至本地web,希望对您有所帮助。
环境说明
操作系统:Windows 10
CUDA 版本为: 10.0
cudnn 版本为: 7.6.5
Python 版本为:Python 3.6.13
tensorflow-gpu 版本为:1.13.1
keras 版本为:2.2.4
flask版本为:2.0.3
注意CUDA、cudnn、Python、tensorflow版本之间的匹配
文件结构
.
│ model.h5
│ model.json
│ webMNIST.py
│
├─static
│ index.js
│ jquery-3.2.0.min.js
│ style.css
│
└─templates
index.html
static 和 templates 的源文件来自:https://github.com/ybsdegit/Keras_flask_mnist
其中,index.html 文件做了一点点修改,修改标题为 MNIST手写数字识别,修改 GitHub Corner。
模型训练
# =====-*- coding: utf-8 -*-=====
# @Time : 2022/3/13 20:12
# @Author: AXYZdong
# @File : Keras.py
# @IDE : Pycharm
# ===============================
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from tensorflow.examples.tutorials.mnist import input_data
# 独热编码
data_folder = "./MNIST_data"
mnist = input_data.read_data_sets(data_folder, one_hot=True)
# 搭建CNN卷积神经网络
model = Sequential()
inputShape = (28, 28, 1) # 输入数据的维度
# 第一层:卷积层+池化层
model.add(Conv2D(32, (5, 5), activation='relu', padding='same', input_shape=inputShape)) # 卷积层
model.add(MaxPooling2D(pool_size=(2, 2))) # 最大池化
# 第二层:卷积层+池化层
model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# 第三层:卷积层+池化层
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# 第四层:卷积层+池化层
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten()) # 将数据展开成一维
model.add(Dense(1024, activation='relu')) # 全连接网络
model.add(Dropout(0.5)) # Dropout处理,防止过拟合
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5)) # Dropout处理,防止过拟合
model.add(Dense(10, activation='softmax')) # 分类
# 编译模型
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(mnist.train.images.reshape((-1, 28, 28, 1)), mnist.train.labels, batch_size=64, epochs=5, verbose=1)
# 保存模型
model.save('model.h5')
本地web创建
# =====-*- coding: utf-8 -*-=====
# @Time : 2022/4/18 15:52
# @Author: AXYZdong
# @File : webMNIST.py
# @IDE : Pycharm
# ===============================
from flask import Flask, render_template, request
import numpy as np
import tensorflow.keras as keras
import re
import base64
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.python.keras.backend import set_session
import tensorflow as tf
app = Flask(__name__)
sess = tf.Session()
graph = tf.get_default_graph()
set_session(sess)
model_file = 'model.h5'
global model
model = keras.models.load_model(model_file)
@app.route('/')
def index():
return render_template("index.html")
@app.route('/predict/', methods=['Get', 'POST'])
def predict():
parseImage(request.get_data())
img = img_to_array(load_img('output.png', target_size=(28, 28), color_mode="grayscale")) / 255.
img = np.expand_dims(img, axis=0)
global graph
global sess
with graph.as_default():
set_session(sess)
predictions = model.predict_classes(img)[0]
return str(predictions)
def parseImage(imgData):
imgStr = re.search(b'base64,(.*)', imgData).group(1)
with open('./output.png', 'wb') as output:
output.write(base64.decodebytes(imgStr))
if __name__ == '__main__':
app.run(host="0.0.0.0", port=3335)
实现效果
参考文献
- [1] https://blog.csdn.net/qq_38534107/article/details/103565899
- [2] https://github.com/ybsdegit/Keras_flask_mnist
如果以上内容有任何错误或者不准确的地方,欢迎在下面 👇 留言。或者你有更好的想法,欢迎一起交流学习~~~
更多精彩内容请前往 AXYZdong的博客