基本思路:将一个一个的 机器学习模型 发布成Web API的功能,方便各个平台调用。
因为第一次这么做,所以遇到很多困难,所以在这里记录下来,已被后面查阅,同时也是给其他有心人一个参考
实现步骤
1 安装 anaconda 。官网地址:Anaconda 官网下载地址
安装anaconda 分为 windows 平台安装 和 linux平台安装。windows平台安装 简单,注意设置环境变量就好了。linux平台安装的话(我得是ubuntu 24)需要使用命令安装
2 创建虚拟环境并激活该虚拟环境
因为Tensorflow版本不同,和他配合的各个软件包 也是需要不同的版本,所以需要一个虚拟Python环境
使用 Anaconda Prompt 安装,命令:conda create -name tensor python=3.9
激活tensor 环境。命令:conda activate tensor
4 修改anaconda 默认软件包下载地址
命令:
【 conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/】
【conda config --set show_channel_urls yes】
4.1 安装Tensorflow
命令:pip install tensorflow
5 安装Django
pip install Django
6 在指定的位置创建网站(最小能启动的web应用程序)
用 命令行 切换到你的项目地址,然后输入命令:
django-admin startproject HelloWorld
于是,程序自动在你的项目位置创建了一个简单的网站,
最后,你可以在命令提示符下输入命令启动网站:
python manage.py runserver 0.0.0.0:8000
Django 教程请参考网址【Django 创建第一个项目 | 菜鸟教程】
7 编写能够手写数字识别的 method方便WebAPI调用
关键配置:在settings.py文件中 注销 django.middleware.csrf.CsrfViewMiddleware,否则Web API 针对POST 请求会报给你 403错误
代码如下:
Web API 代码如下
#下面代码是在 views.py中增加的
from django.http import HttpResponse
from .keras_model import keras_model
from django.views.decorators import csrf
from django.shortcuts import render
from pathlib import Path
import os
from rest_framework import generics
import numpy as np
def predic_img(request):
print('enter predic_img')
ctx ={}
if request.POST:
print('post')
#username=request.POST.get('username','')
#password=request.POST.get('password','')
# 针对base64 进行解码,解码后还原图片,并进行预测
hex_string = request.POST.get('image')
byte_array = bytearray.fromhex(hex_string)
bytes_object = bytes(byte_array) # 将bytearray转换为bytes
BASE_DIR = Path(__file__).resolve().parent.parent
path = os.path.join(BASE_DIR, 'images')
print(f'path = {path}')
with open(f'{path}\img.png', 'wb') as file:
file.write(bytes_object)
# 将 bytes 转换成 numpy的 ndarray
print(f'type of bytes is {type(bytes_object)}')
nd_array = np.frombuffer(bytes_object, dtype=np.uint8)
nd_array = pad(nd_array,(1,28,28))
# 预测
obj = keras_model()
ret = obj.predictByImg(nd_array)
ret = '{"result":"%s"}' % (ret)
ctx['rlt'] = ret
else:
print('no post')
return render(request, "post.html", ctx)
# 如果数组和要变形的数组不匹配,需要补齐数组
def pad(source, shape):
out = np.zeros(shape)
print(f'type of source is {type(source)}')
if source.size > out.size :
pass
elif source.size < out.size :
arr_append = np.zeros( out.size - source.size )
#source = source.extend( arr_append )
source = np.append( source,arr_append )
result = source.reshape(shape)
print(f'shape of result is {result.shape}')
return out
预测模型代码
mport tensorflow as tf
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import os
from . import settings
class keras_model:
def __init__(self):
print('这是张强的 keras 测试类')
def run(self):
print('开始运行kera 手写数字识别 模型')
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #数据类型恐怕是 float64
#定义训练模型
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
#定义模型的优化器、损失函数 和 准确率
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#训练
model.fit(x_train, y_train, epochs=5)
#保存模型
model.save('hand_digii_recognition.keras')
print('已经保存模型')
#评估
model.evaluate(x_test, y_test)
print('结束运行kera 手写数字识别 模型')
#预测手写数字
#参数:测试集 索引
def predict(self,index):
index = int(index)
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
model = load_model('./models/hand_digii_recognition.keras')
#print( x_test.shape)
pic = x_test[index-1:index,::]
#for p in pic:
# plt.imshow(p , cmap=plt.get_cmap('gray'))
# plt.show()
#print( pic.shape)
result = model.predict( pic)
#print(f'result = {result}')
#print(f'the max of ndarray is {np.max(result)}')
#找出结果中最大 值 的索引,即预测的数字
#print(np.argmax(result))
return np.argmax(result)
#预测图片数字
#参数:图片张量
def predictByImg(self,pic):
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
model = load_model(f'{settings.BASE_DIR}/models/hand_digii_recognition.keras')
current_path = os.path.abspath(__file__)
print("当前脚本的路径:", current_path)
#model = load_model('d:/hand_digii_recognition.keras')
result = model.predict( pic)
return np.argmax(result)
def save(self):
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
pic = x_train[0:1,::]
plt.imshow(pic , cmap=plt.get_cmap('gray'))
plt.show()
8 客户端代码(我是Winform的 )
注意:我是从npz包中取出了一张图片,用于这个项目的
// 核心代码
private void DoPost(String pic_path)
{
String url = String.Format("http://127.0.0.1:8000/predic_img/");
String res = "";
//String postData = String.Format(@"username={0}&password={1}",username ,password);
HttpWebRequest req = (HttpWebRequest)WebRequest.Create(new Uri(url));
req.AllowAutoRedirect = true;
//req.CookieContainer = Token;
req.Method = "POST";
req.Timeout = 10000;
req.ContentType = "application/x-www-form-urlencoded;charset=utf-8";
//req.ContentType = "application/x-www-form-urlencoded";
try
{
//byte[] data = Encoding.UTF8.GetBytes(postData);
byte[] data = GetBase64FromPic("image",pic_path);
req.ContentLength = data.Length;
Stream stream = req.GetRequestStream();
stream.Write(data, 0, data.Length);
stream.Close();
HttpWebResponse response = (HttpWebResponse)req.GetResponse();
#region 设置cookie
//string setCookie = response.Headers.Get("Set-Cookie");
//Set Cookie
//Get Response Stream
#endregion
StreamReader sr = new StreamReader(response.GetResponseStream(), Encoding.UTF8);
res = sr.ReadToEnd();
sr.Close();
response.Close();
}
catch (Exception ex)
{
MessageBox.Show(ex.Message);
}
}
/// <summary>
/// 字节数组转16进制字符串:空格分隔
/// </summary>
/// <param name="byteDatas"></param>
/// <returns></returns>
public static string ToHexStrFromBytes(byte[] byteDatas)
{
StringBuilder builder = new StringBuilder();
for (int i = 0; i < byteDatas.Length; i++)
{
builder.Append(string.Format("{0:X2} ", byteDatas[i]));
}
return builder.ToString().Trim();
}
9 这样就完成了 客户单发送手写图片,服务端接收图片,然后进行预测并给出结果