目录
背景
先定性,带AI识别的生鲜收银机早就上市了,目前学习的只能说是别人玩剩的,但是依然收获满满,算是第一个AI识别的应用吧。关键是技术栈全是前端的,还是有一定参考价值。
原理
我们自己的果蔬AI识别系统参考了市面上的一套成熟方案:
识别过程:
学习过程:
技术选型
技术栈
- python
其实一开始是准备使用python的,因为性能方面更快,也更好写,开源的模型也多,但是不得不面对一个问题,就是我们的硬件是一套安卓系统,如果用python的话,根据厂家的示例,要么将编写好的python程序转成c++的,要么将python程序部署为服务,以RESTful接口的形式提供。第一个方案太不熟了,毕竟还有业务压力,pass。第二个方案中途试了,但考虑到网络损耗,觉得还是部署本地的模型是最优解。
可以附上python部分代码:
模型就是用的mobilenet_fv.h5
识别部分:没有使用数据库功能,本地提供了一个缓存机制,半成品
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import collections
import hashlib
from heapq import nlargest
# Maximum number of entries the feature cache may hold.
CACHE_CAPACITY = 1000
# Feature cache. An OrderedDict keeps insertion order, which makes it easy to
# evict the oldest entry first.
cache = collections.OrderedDict()
# MobileNet with ImageNet weights, built without the classification head and
# with global average pooling.
model = MobileNet(weights='imagenet', include_top=False, pooling='avg')
def cosine_similarity(features1, features2):
    """Return the cosine similarity of two feature arrays.

    Both inputs are flattened first, so arrays of shape ``(1, n)`` and
    ``(n,)`` are treated alike.

    Args:
        features1: First feature array (any array-like).
        features2: Second feature array with the same number of elements.

    Returns:
        The cosine of the angle between the flattened vectors, or ``0.0``
        when either vector has zero length.
    """
    vec1 = np.asarray(features1).ravel()
    vec2 = np.asarray(features2).ravel()
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # A zero vector has no direction; without this guard the division is 0/0
    # and yields NaN, which silently breaks every later comparison.
    if denominator == 0:
        return 0.0
    return np.dot(vec1, vec2) / denominator
def query_cache(image_features, threshold=0.5):
    """Return the label of the cached entry most similar to *image_features*.

    Args:
        image_features: Feature array of the query image.
        threshold: Minimum cosine similarity the best match must reach.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        The label of the best match, or ``None`` when the cache is empty or
        no entry reaches *threshold*.
    """
    best_similarity = -1
    best_label = None
    # Only the stored values are needed; the image-id keys play no role here.
    for cached_features, label in cache.values():
        similarity = cosine_similarity(image_features, cached_features)
        if similarity > best_similarity:
            best_similarity = similarity
            best_label = label
    return best_label if best_similarity >= threshold else None
def query_cache_top5(image_features):
    """Return up to five cached entries most similar to *image_features*.

    Entries are ranked on the numeric similarity and formatted only
    afterwards. Ranking the formatted strings would compare them
    lexicographically, which misorders negative values ("-0.9000" sorts
    above "-0.1000").

    Args:
        image_features: Feature array of the query image.

    Returns:
        A list of ``(similarity, label)`` tuples ordered from most to least
        similar, where ``similarity`` is a string with four decimals. The
        list is empty when the cache is empty.
    """
    scored = [
        (float(cosine_similarity(image_features, cached_features)), label)
        for cached_features, label in cache.values()
    ]
    return [
        (f"{similarity:.4f}", label)
        for similarity, label in nlargest(5, scored)
    ]
def get_image_id_from_hash(img):
    """Return the hexadecimal MD5 digest of ``img.tobytes()``.

    The digest is used as an identifier for the image, not for security.
    """
    digest = hashlib.md5()
    digest.update(img.tobytes())
    return digest.hexdigest()
class MainDetect:
    """Image recognizer backed by a model and the module-level feature cache.

    ``predict_img`` runs the model loaded from disk (``self.model``), while
    ``classify_image`` runs the module-level ``model``. Both remember the id
    and the model output of the last processed image so that
    ``update_cache`` can store them under a label afterwards.
    """

    def __init__(self):
        super().__init__()
        # Id and model output of the most recently processed image.
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/mobilenet_fv.h5")  # TODO: adjust the model file name
        # TODO: adjust the class names; this list is printed at the start of
        # model training.
        self.class_names = ['哈密瓜', '柠檬', '桂圆', '梨', '榴莲', '火龙果', '猕猴桃', '胡萝卜', '芒果', '苦瓜',
                            '草莓', '荔枝', '菠萝', '车厘子', '黄瓜']

    @staticmethod
    def _to_model_input(img):
        """Turn a PIL image into a ``(1, 224, 224, 3)`` array for the model."""
        # convert("RGB") guarantees three channels, so grayscale and palette
        # images work too; slicing the raw array with [:, :, :3] raised on
        # two-dimensional (single-channel) arrays.
        rgb = np.asarray(img.convert("RGB"))
        resized = cv2.resize(rgb, (224, 224))
        return resized.reshape(1, 224, 224, 3)

    def _recognize(self, image_data, run_model):
        """Decode *image_data*, run *run_model* on it and query the cache.

        Args:
            image_data: Encoded image bytes.
            run_model: Callable that maps the input batch to the model output.

        Returns:
            A dict with the cache matches (``"result"``), the raw model
            output (``"outputs"``) and the inference time (``"time"``).
        """
        img = Image.open(io.BytesIO(image_data))
        image_id = get_image_id_from_hash(img)
        target = self._to_model_input(img)
        start_time = time.time()
        outputs = run_model(target)
        elapsed_time = time.time() - start_time
        # Store id and features together, and only after the model succeeded,
        # so update_cache never pairs a new id with stale features.
        self.image_id = image_id
        self.image_features = outputs
        result = query_cache_top5(outputs)
        return {
            "result": result,
            "outputs": outputs,
            "time": f"{elapsed_time * 1000:.2f}ms",
        }

    def predict_img(self, image_data):
        """Process *image_data* with the model loaded from disk.

        Returns:
            The dict described in ``_recognize``.
        """
        return self._recognize(
            image_data,
            lambda batch: self.model.predict(batch, batch_size=1),
        )

    def classify_image(self, image_data):
        """Process *image_data* with the module-level ``model``.

        Returns:
            The dict described in ``_recognize``.
        """
        return self._recognize(image_data, lambda batch: model.predict(batch))

    def update_cache(self, label):
        """Store the last processed image's features under *label*.

        Returns:
            ``True`` when the entry was stored, ``False`` when no image has
            been processed yet (there is nothing to store).
        """
        if self.image_id is None or self.image_features is None:
            # Storing (None, label) would make every later similarity
            # computation against this entry fail.
            return False
        entry = (self.image_features, label)
        if self.image_id in cache:
            # Overwriting an existing key does not grow the cache, so nothing
            # has to be evicted; just refresh the entry's position.
            cache[self.image_id] = entry
            cache.move_to_end(self.image_id)
        else:
            if len(cache) >= CACHE_CAPACITY:
                # Cache is full: drop the oldest entry.
                cache.popitem(last=False)
            cache[self.image_id] = entry
        # Record each label once; appending on every call let the list grow
        # without bound in a long-running process.
        if label not in self.class_names:
            self.class_names.append(label)
        return True

    def clear_cache(self):
        """Remove every entry from the feature cache and return ``True``."""
        cache.clear()
        return True
服务部分:
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
import numpy as np
# Single detector instance used by the request handlers.
detector = MainDetect()
app = Flask(__name__)
# Allow cross-origin requests on all routes.
CORS(app)
@app.route('/')
def home():
    """Serve the root route with a plain-text greeting."""
    greeting = "Welcome to the Vegetable Recognize App!"
    return greeting
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return 'No file part', 400
file = request.files['file']
if file.filename == '':
return 'No selected file', 400
try:
image_data = file.read()
data = detector.classify_image(image_data)
result = data["result"]
outputs = data["outputs"]
time = data["time"]
return jsonify({
"top5": result, "time": time, "features": outputs.tolist()})
except Exception as e:
return jsonify({
'error': str(e)}