果蔬识别系统性能优化之路

9 篇文章 0 订阅

一级目录

二级目录

三级目录

前情提要

超详细前端AI蔬菜水果生鲜识别应用优化之路

当前问题

  1. indexddb在webview中确实性能有限,存储量上来后每次读取数据会有明显卡顿
  2. 目前的余弦相邻算法是基于所有特征向量数据进行的计算,一旦数据量大了后计算量也是一个十分消耗性能的点
  3. 由于机器性能问题,本地化加载模型在每次进入页面后需要消耗很长一段时间进行模型的加载,体验相当不好

优化方案

1. 内存化

方式:为了解决indexddb读取速度的问题,最直接的方式就是把数据放内存对象中,提前将indexddb的数据取出来,然后在数据变化时进行同步
结果:确实省去了读取这一步会快很多,但是仍有隐患,webview分配的内存是否这种占用内存的方式

2. 原生化

本地化模型加载可以在原生层面进行模型的加载,识别,学习,相当于把整套方案在原生端实现一次,通过bridge调用原生相关方法完成识别,但目前暂无学习原生的意向,所以搁置

3. 接口化

方式:将数据存储在mysql,识别单独起一个python服务,通过接口调用,利用服务器的性能优势使识别速度提升,保证网络消耗在合理范围即可

行动

最终选择了接口化,在保证网络的情况下,识别速度在300ms内可被接受

  1. nestjs搭建服务端,分feature和img两个模块
  2. mysql搭建数据库,建立feature和img两个表,通过imgId和feature表关联,同时feature表包含storeCode字段,用来管理门店
  3. 搭建redis,代替内存化方案,实现快速读取
  4. 搭建python服务端,与nestjs服务端进行通信,进行识别结果的传输
  5. 通过IVF方式提升计算速度,保证大量特征值情况下仍然可以快速算出相似结果

实现

  1. python端,flask搭建http服务,当然后续可以改成其他和服务端通信方式,提升速度
  2. 实现识别接口恶化同步接口,用于识别图片特征值和同步数据库存储的特征向量
    app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from detect import MainDetect
from tensorflow.keras import layers, models

app = Flask(__name__)
CORS(app)  # 允许所有路由上的跨域请求
detector = MainDetect()


@app.route('/')
def home():
    return "Welcome to the Vegetable Recognize App!"


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return 'No file part', 400

    file = request.files['file']
    store_code = request.form["storeCode"]
    if file.filename == '':
        return 'No selected file', 400

    try:
        image_data = file.read()
        data = detector.classify_image(image_data, store_code)
        outputs = data["outputs"]
        time = data["time"]
        index = data["index"]
        return jsonify({"predictTime": time, "features": outputs, "index": index})

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/sync', methods=['POST'])
def sync():
    data = request.get_json()
    arr = data.get('data')
    store_code = data.get('storeCode')
    detector.sync(store_code, arr)
    return jsonify({"message": 'ok'})

detect.py

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
from PIL import Image
import numpy as np
import cv2
import time
import io
import gc
from ivf import IVFPQ

model = MobileNet(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    # 初始化
    def __init__(self):
        super().__init__()
        # 模型初始化
        self.image_id = None
        self.image_features = None
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        self.ivfObj = {}

    def classify_image(self, image_data, store_code):
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))
        i = []
        if store_code + '-featureDatabase' in self.ivfObj:
            i = self.ivfObj[store_code + '-featureDatabase'].search(outputs)
            i = i.flatten().tolist()

        end_time = time.time()

        # Calculate elapsed time
        elapsed_time = end_time - start_time

        # Flatten the outputs and return them
        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Force garbage collection to free up memory
        del img, outputs, end_time, start_time  # Ensure variables are deleted
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code, data):
        if store_code + '-featureDatabase' in self.ivfObj:
            del self.ivfObj[store_code + '-featureDatabase']
        self.ivfObj[store_code + '-featureDatabase'] = IVFPQ()
        for item in data:
            feature = item['features']
            self.ivfObj[store_code + '-featureDatabase'].add(np.array([feature], dtype=np.float32))
        return 'ok'

ivf.py

import faiss
import numpy as np


class IVFPQ:
    def __init__(self, d=1024, nlist=1, m=16, n_bits=8):
        # 创建量化器
        quantizer = faiss.IndexFlatL2(d)  # 使用L2距离进行量化
        self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        np.random.seed(1234)
        xb = np.random.random((256, d)).astype('float32')  # 模拟数据库中的特征向量
        # 训练索引
        self.index.train(xb)
        self.index.add(xb)  # 将特征向量添加到索引中

    def search(self, xq, k=50):
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        self.index.add(xb)

    def sync(self, features):
        for i in range(len(features)):
            self.add(features[i])

  1. nestjs进行接口的转发和数据处理
    识别的service,调用python服务,并通过返回的索引找到正确的目标label
  /**
   * 预测
   * @param file
   * @param num
   * @param storeCode
   * @param justPredict
   */
  async predict(file: Express.Multer.File, num: string = '5', storeCode: string, justPredict: Boolean = false) {
    const url = 'http://localhost:5000/predict'; // Python 服务的 URL
    const startTime = Date.now();
    try {
      // 返回 Python 服务的响应数据
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      const response = await firstValueFrom(this.httpService.post(url, formData));
      const endTime = Date.now();
      const features = response.data.features;
      const index = response.data.index;
      if (justPredict) {
        return features;
      }
      // const top5 = await this.findTopNSimilar(features, parseInt(num), storeCode);

      const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
      if (!featureDatabaseStr) {
        return response.data = {
          ...response.data,
          [`top${num}`]: [],
          features,
          totalTime: `${endTime - startTime}ms`,
        };
      }
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const list = [];
      index.forEach((i: number) => {
        const ide = i - 256;
        if (ide >= 0) {
          const item = featureDatabase[ide];
          if (!list.some(l => l.label === item.label)) {
            list.push({ label: item.label });
          }
        }
      });
      return response.data = {
        ...response.data,
        [`top${num}`]: list,
        features,
        totalTime: `${endTime - startTime}ms`,
      };
    } catch (error) {
      // 错误处理
      console.error('Error calling Python service:', error);
      throw error;
    }
  }

同步python端特征向量,在数据库的增删改查时进行调用

  /**
   * 同步redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(featureDatabase));
    const url = 'http://localhost:5000/sync'; // Python 服务的 URL
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }

结语

没规划好,其实全部用python实现应该更自然,后续有时间再更新,先上线跑跑

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值