利用flask深度学习模型部署web应用部署(MINST)
通过模型部署3/3-手把手实现利用flask深度学习模型部署
地址:https://zhuanlan.zhihu.com/p/273252334
Anaconda 先创建,安装基础环境
conda create -n TF1.3.0 python=3.7
然后进入创建好的环境
conda activate TF1.3.0
安装相关软件包,我这里安装了cudatoolkit-10.1安装因为大SDk10.1—如果报错就是不支持英伟达
可以直接跳过
conda install cudatookit=10.1
conda install cudnn=7.6
安装Tensorflow这里我指定了1.3.0版本 ==》这里创建失败就不继续下去了
pip install tensorflow==1.3.0
我这里直接使用了上次搭建TF2.1的环境
可去mooc Tensorflow笔记学习
因为这里from scipy.misc import imread 会出错,
这里因为最新1.3.0版本的scipy没有这个模块了,你要用scipy.misc的imread就暗转scipy1.2.0版本就行。
pip install scipy==1.2.0,然后from scipy.misc import imread就行了
pip install scipy==1.2.0
我这里成功安装版本为scipy-1.4.1然后tf.Session报错 这里报错:
AttributeError: module ‘tensorflow’ has no attribute ‘Session’。
这其实不是安装错误,是因为在新的Tensorflow 2.0版本中已经移除了Session这一模块,改换运行代码
sess = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()
就可以获得与原先相同的输出信息。如果觉得不方便,也可以改换低版本的Tensorflow,
直接用pip即可安装
附上完整Kerasflask.py的代码:
from flask import Flask, render_template, request
from scipy.misc import imread, imresize, imsave
import tensorflow as tf
import numpy as np
import re
import base64
from tensorflow.keras.models import load_model
from tensorflow.python.keras.backend import set_session
# 1. 初始化 flask app
app = Flask(__name__)
# 2. 初始化global variables
# sess = tf.Session()因为这里是高版本TF2.1改为以下代码
sess = tf.compat.v1.Session()
# graph = tf.get_default_graph()因为这里是高版本TF2.1改为以下代码
graph = tf.compat.v1.get_default_graph()
# 3. 将用户画的图输出成output.png
def convertImage(imgData1):
imgstr = re.search(r'base64,(.*)', str(imgData1)).group(1)
with open('output.png', 'wb') as output:
output.write(base64.b64decode(imgstr))
# 4. 搭建前端框架
@app.route('/')
def index():
return render_template("index.html")
# 5. 定义预测函数
@app.route('/predict/', methods=['GET', 'POST'])
def predict():
# 这个函数会在用户点击‘predict’按钮时触发
# 会将输出的output.png放入模型中进行预测
# 同时在页面上输出预测结果
imgData = request.get_data()
convertImage(imgData)
# 读取图片
x = imread('output.png', mode='L')
# 设置图片的规格
x = imresize(x, (28, 28))/255
# 可以保存最终处理好的图片
imsave('final_image.jpg', x)
x = x.reshape(1, 28, 28, 1)
# 调用训练好的模型和并进行预测
global graph
global sess
with graph.as_default():
set_session(sess)
model = load_model('model.h5')
out = model.predict(x)
response = np.argmax(out, axis=1)
return str(response[0])
# 6. 返回本地访问地址
if __name__ == "__main__":
# 让app在本地运行,定义了host和port
app.run(host='0.0.0.0', port=5000)
报错 model_config = json.loads(model_config.decode(‘utf-8’))
AttributeError: ‘str’ object has no attribute ‘decode’
解决方法如下
pip install tensorflow h5py==2.10.0
成功
这里附上文件的结构: