Fruit and Vegetable Recognition System: The Road to Performance Optimization (Part 2)

Recap

Fruit and Vegetable Recognition System: The Road to Performance Optimization (previous post)

The New Problem

After the trial run from the last post, recognition turned out to be very slow. Breakpoints and timing logs showed that although Redis itself is fast, it still becomes a performance bottleneck once we read and write large amounts of data through it, especially when huge data sets such as feature vectors are read and written frequently.

Optimization

  1. The original motivation for using Redis was simply to map the indices returned by recognition back to their real label values. Besides, after the IVF lookup the candidate set is already much smaller, so at that point a fast MySQL query can produce the final result directly.
    The full code first:
import { Injectable } from '@nestjs/common';
import { CreateFeatureDto } from './dto/create-feature.dto';
import { Feature } from './entities/feature.entity';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, In } from 'typeorm';
import { RedisService } from '../redis/redis.service';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
import * as FormData from 'form-data';
import { Img } from '../img/entities/img.entity';

@Injectable()
export class FeatureService {
  constructor(
    @InjectRepository(Feature)
    private readonly featureRepository: Repository<Feature>,
    @InjectRepository(Img)
    private readonly imgRepository: Repository<Img>,
    private readonly httpService: HttpService,
    private readonly redisService: RedisService,
  ) {
  }

  /**
   * Create a feature record
   * @param file
   * @param createFeatureDto
   * @param needSync // whether to also sync Redis afterwards, defaults to true
   */
  async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
    const img = this.imgRepository.create({
      img: file.buffer,
    });
    await this.imgRepository.save(img);
    const feature = this.featureRepository.create({
      ...createFeatureDto,
      imgId: img.id,
    });
    const { storeCode } = await this.featureRepository.save(feature);
    needSync && await this.syncRedis(storeCode);
    return feature;
  }

  /**
   * Sync Redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    const ids = featureDatabase.map(({ id }) => id);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(ids));
    const url = 'http://localhost:5000/sync'; // URL of the Python service
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }

  /**
   * Find all features for a store
   * @param storeCode
   */
  async findAll(storeCode: string) {
    return await this.featureRepository.createQueryBuilder('feature').select(['feature.id', 'feature.features']).where('FIND_IN_SET(:storeCode, feature.storeCode)', { storeCode }).getMany();
  }

  /**
   * Find features together with their associated images
   * @param storeCode
   */
  async findAllWithImage(storeCode: string): Promise<Feature[]> {
    return await this.featureRepository.createQueryBuilder('feature').leftJoinAndSelect('feature.img', 'img').where('FIND_IN_SET(:storeCode, feature.storeCode)', { storeCode }).getMany();
  }

  /**
   * Delete all of a store's data
   * @param storeCode
   */
  async removeAll(storeCode: string): Promise<void> {
    const features = await this.findAllWithImage(storeCode);
    for (const feature of features) {
      await this.remove(feature, storeCode);
    }
    await this.redisService.del(`${storeCode}-featureDatabase`);
  }

  /**
   * Predict
   * @param file
   * @param num
   * @param storeCode
   * @param justPredict
   * @param needList
   */
  async predict(
    file: Express.Multer.File,
    num: string = '5',
    storeCode: string,
    justPredict: string = 'false',
    needList: boolean = false,
  ) {
    const PYTHON_SERVICE_URL = 'http://localhost:5000/predict'; // Python service URL
    const REDIS_KEY_PREFIX = '-featureDatabase';
    const startTime = Date.now();
    const numInt = parseInt(num);
    const isJustPredict = justPredict === 'true';

    try {
      // Prepare form data
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      formData.append('justPredict', justPredict);

      // Send request to Python service
      const response = await firstValueFrom(this.httpService.post(PYTHON_SERVICE_URL, formData));
      const { features, index, predictTime } = response.data;

      if (isJustPredict) {
        return { features };
      }

      // Retrieve feature database from Redis
      const featureDatabaseStr = await this.redisService.get(`${storeCode}${REDIS_KEY_PREFIX}`);
      if (!featureDatabaseStr) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // Parse the Redis result and filter the IDs
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const ids = index
        .map((idx: number) => featureDatabase[idx]);

      if (!ids.length) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // Query for features in the database
      const featureList = await this.featureRepository.createQueryBuilder('feature')
        .where('feature.id IN (:...ids)', { ids })
        .orderBy(`FIELD(feature.id, ${ids.map((id: any) => `'${id}'`).join(', ')})`, 'ASC')
        .getMany();

      // Filter to ensure unique labels
      const uniqueList = this.filterUniqueFeatures(featureList, numInt);

      const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
      return needList ? { ...result, featureList: featureList.map(({ features, ...rest }) => rest) } : result;
    } catch (error) {
      throw new Error(`Prediction failed: ${error.message}`);
    }
  }

  private filterUniqueFeatures(featureList: any[], limit: number) {
    const uniqueList = [];
    for (const feature of featureList) {
      if (!uniqueList.some(f => f.label === feature.label)) {
        uniqueList.push(feature);
      }
      if (uniqueList.length === limit) break;
    }
    return uniqueList;
  }

  private buildResponse(list: any[], features: any, predictTime: string, startTime: number, num: number) {
    const totalTime = `${Date.now() - startTime}ms`;
    return {
      predictTime,
      [`top${num}`]: list.map(({ features, ...rest }) => rest),
      features,
      totalTime,
    };
  }

  /**
   * Compute cosine similarity
   * @param vecA
   * @param vecB
   */
  cosineSimilarity(vecA: number[], vecB: number[]): number {
    if (vecA.length !== vecB.length) {
      throw new Error('Vectors must be of the same length');
    }
    const dotProduct = vecA.reduce((sum, value, index) => sum + value * vecB[index], 0);
    const magnitudeA = Math.sqrt(vecA.reduce((sum, value) => sum + value * value, 0));
    const magnitudeB = Math.sqrt(vecB.reduce((sum, value) => sum + value * value, 0));
    return dotProduct / (magnitudeA * magnitudeB);
  }

  /**
   * Find the most similar features by brute-force cosine similarity (note: this still expects full feature objects in Redis, i.e. the layout from before this optimization)
   * @param inputFeatures
   * @param num
   * @param storeCode
   */
  async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{
    label: string;
    similarity: number
  }[]> {
    const featureDatabaseStr = await this.redisService.get(`${storeCode}-featureDatabase`);
    if (!featureDatabaseStr) {
      return [];
    }
    const featureDatabase = JSON.parse(featureDatabaseStr);
    const similarities = featureDatabase.map(({ features, label }) => {
      let similarity = 0;
      if (features) {
        similarity = this.cosineSimilarity(inputFeatures, features);
      }
      return { label: label as string, similarity: similarity as number };
    });

    similarities.sort((a: { similarity: number; }, b: { similarity: number; }) => b.similarity - a.similarity);

    const uniqueLabels = new Set<string>();
    const topNUnique: { label: string; similarity: number; }[] = [];
    for (const item of similarities) {
      if (!uniqueLabels.has(item.label as string)) {
        uniqueLabels.add(item.label);
        item.similarity = Math.round(item.similarity * 100) / 100;
        topNUnique.push(item);
        if (topNUnique.length === num) break;
      }
    }
    return topNUnique;
  }

  /**
   * Find by label
   * @param label
   * @param storeCode
   */
  async getByName(label: string, storeCode: string): Promise<Feature[]> {
    return await this.featureRepository.createQueryBuilder('feature').leftJoinAndSelect('feature.img', 'img').where('FIND_IN_SET(:storeCode, feature.storeCode)', { storeCode }).andWhere('feature.label = :label', { label }).getMany();
  }

  /**
   * Count feature vectors by label
   * @param label
   * @param storeCode
   */
  async getCountByLabel(label: string, storeCode: string): Promise<number> {
    return await this.featureRepository.createQueryBuilder('feature').leftJoinAndSelect('feature.img', 'img').where('FIND_IN_SET(:storeCode, feature.storeCode)', { storeCode }).andWhere('feature.label = :label', { label }).getCount();
  }

  /**
   * Batch learning
   * @param files
   * @param createFeatureDto
   */
  async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
    const list = [];
    for (const file of files) {
      try {
        const features = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
        const feature = await this.create(file, {
          ...createFeatureDto,
          //@ts-ignore
          features,
        }, false);
        list.push(feature);
      } catch (e) {
        console.error(e);
      }
    }
    await this.syncRedis(createFeatureDto.storeCode);
    return list;
  }

  /**
   * Remove a store's feature data
   * @param feature
   * @param storeCode
   */
  async remove(feature: Feature, storeCode: string) {
    // Split the existing store codes
    let storeCodeArray = feature.storeCode.split(',');

    // Remove the specified store code
    storeCodeArray = storeCodeArray.filter(code => code !== storeCode);

    if (storeCodeArray.length === 0) { // no store references this record any more, so delete the record and its image
      await this.featureRepository.remove(feature);
      await this.imgRepository.remove(feature.img);
    } else {
      // Update the storeCode column
      feature.storeCode = storeCodeArray.join(',');
      await this.featureRepository.save(feature);
    }
  }

  /**
   * Batch delete
   * @param ids
   * @param storeCode
   */
  async batchRemove(ids: string, storeCode: string) {
    const list = ids.split(',');
    for (const id of list) {
      const feature = await this.featureRepository.findOne({
        where: { id: +id },
        relations: ['img'],  // load the related Img entity
      });
      feature && await this.remove(feature, storeCode);
    }
    await this.syncRedis(storeCode);
  }

  /**
   * Import data into a store
   * @param storeCode
   */
  async importData(storeCode: string) {
    // Find all Feature entities not yet associated with this store
    const features = await this.featureRepository.createQueryBuilder('feature')
      .where('NOT FIND_IN_SET(:storeCode, feature.storeCode)', { storeCode })
      .getMany();
    // Update the storeCode column on each record
    for (const feature of features) {
      let storeCodeArray = feature.storeCode.split(',');
      storeCodeArray.push(storeCode);
      // Update the storeCode column
      feature.storeCode = storeCodeArray.join(',');
      // Save the updated record
      await this.featureRepository.save(feature);
    }
    await this.syncRedis(storeCode);
    return `Sync complete: ${features.length} records imported`;
  }

  async init() {
    // Run a raw SQL query
    const result = await this.featureRepository.query(
      'SELECT DISTINCT storeCode FROM feature',
    );
    const did = [];
    for (const row of result) {
      const storeCodes = row.storeCode.split(',');
      for (const storeCode of storeCodes) {
        if (!did.includes(storeCode)) {
          await this.syncRedis(storeCode);
          did.push(storeCode);
        }
      }
    }
    console.log('Initialization complete');
  }
}

  2. The init method: when the service starts, sync the ids and feature vectors to the Python side so that everything stays consistent with MySQL
 async init() {
    // Run a raw SQL query
    const result = await this.featureRepository.query(
      'SELECT DISTINCT storeCode FROM feature',
    );
    const did = [];
    for (const row of result) {
      const storeCodes = row.storeCode.split(',');
      for (const storeCode of storeCodes) {
        if (!did.includes(storeCode)) {
          await this.syncRedis(storeCode);
          did.push(storeCode);
        }
      }
    }
    console.log('Initialization complete');
  }
  3. The sync method: Redis now stores only the ids of the relevant rows of the feature table, while the full data set is pushed to the Python side (a hypothetical sketch of the receiving end follows this list)
/**
   * Sync Redis
   * @param storeCode
   */
  async syncRedis(storeCode: string) {
    const featureDatabase = await this.findAll(storeCode);
    const ids = featureDatabase.map(({ id }) => id);
    await this.redisService.set(`${storeCode}-featureDatabase`, JSON.stringify(ids));
    const url = 'http://localhost:5000/sync'; // URL of the Python service
    await firstValueFrom(this.httpService.post(url, { data: featureDatabase, storeCode }));
  }
  4. After this optimization, reading from Redis and mapping the returned indices to ids is very fast

      // Retrieve feature database from Redis
      const featureDatabaseStr = await this.redisService.get(`${storeCode}${REDIS_KEY_PREFIX}`);
      if (!featureDatabaseStr) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // Parse the Redis result and filter the IDs
      const featureDatabase = JSON.parse(featureDatabaseStr);
      const ids = index
        .map((idx: number) => featureDatabase[idx]);
  5. The rows are then fetched with a single SQL query and returned; the FIELD() clause keeps them in the same order as the similarity ranking produced by the index search
  const featureList = await this.featureRepository.createQueryBuilder('feature')
        .where('feature.id IN (:...ids)', { ids })
        .orderBy(`FIELD(feature.id, ${ids.map((id: any) => `'${id}'`).join(', ')})`, 'ASC')
        .getMany();
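
The article calls a separate Python service for /predict and /sync but never shows it. Purely for orientation, here is a minimal sketch of what the receiving end of /sync could look like, assuming Flask and faiss; the endpoint name and the { data, storeCode } payload match what syncRedis() sends above, but the index type, the nlist heuristic and the way features are serialized are assumptions, not the actual service:

import json

import faiss
import numpy as np
from flask import Flask, jsonify, request

app = Flask(__name__)
indexes = {}  # one IVF index per storeCode, kept in memory


@app.route('/sync', methods=['POST'])
def sync():
    body = request.get_json()
    store_code = body['storeCode']
    rows = body['data']  # [{ id, features }, ...] as sent by syncRedis()
    if not rows:
        indexes.pop(store_code, None)
        return jsonify({'count': 0})
    # 'features' may be serialized as a JSON string in MySQL; parse defensively
    vectors = np.array(
        [json.loads(r['features']) if isinstance(r['features'], str) else r['features'] for r in rows],
        dtype='float32',
    )
    dim = vectors.shape[1]
    nlist = max(1, min(100, len(rows) // 39))  # rough rule of thumb for the number of IVF cells
    quantizer = faiss.IndexFlatL2(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist)
    index.train(vectors)
    # position i in this index corresponds to position i in the id list stored in Redis,
    # which is exactly what makes featureDatabase[idx] work on the Nest side
    index.add(vectors)
    indexes[store_code] = index
    return jsonify({'count': len(rows)})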

Training

Batch-train on the images and their labels to build the feature-vector database

import os
import requests

# Define the directory and API URL
base_dir = r'D:\workspace\datasets\geneFruit\train'
api_url = 'http://localhost:3002/feature/batchStudy'


def train_data(store_code):
    # Iterate through each folder (label) and the images within
    for label in os.listdir(base_dir):
        label_path = os.path.join(base_dir, label)
        if os.path.isdir(label_path):
            files = []
            # Add all images from the folder to the 'files' dictionary
            for idx, image_name in enumerate(os.listdir(label_path)):
                image_path = os.path.join(label_path, image_name)
                # Check if the file is an image
                if image_name.lower().endswith(('png', 'jpg', 'jpeg', 'bmp', 'gif')):
                    files.append(('files', (image_name, open(image_path, 'rb'), 'image/jpeg')))
            if len(files) != 0:  # Only send if there are images in the folder
                # Prepare the payload
                payload = {
                    'storeCode': store_code,
                    'label': label
                }
                batch_size = 10  # Define how many files to upload per batch
                for i in range(0, len(files), batch_size):
                    batch_files = files[i:i + batch_size]
                    try:
                        requests.post(api_url, files=batch_files, data=payload, timeout=1200)
                    except requests.exceptions.RequestException as e:
                        print(f'Error sending images from folder {label}: {e}')
                        continue
                # Ensure all files are closed after the request
                for _, (file_name, file, mime_type) in files:
                    file.close()
    return f'Successfully sent images from folder {base_dir}'
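
To actually run the trainer, call it with your store code; '1001' below is just a placeholder value for illustration:

if __name__ == '__main__':
    # '1001' is a placeholder store code
    print(train_data('1001'))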

Validation

Precision and Recall

In a vector-retrieval scenario using IVF (Inverted File Index), precision and recall are computed in much the same way as for a classification task, except that they are based on the retrieval results. In vector retrieval you typically look for the vectors most similar to a query vector (for example via nearest-neighbor search) and then compute precision and recall with the following steps.

  1. Mapping the concepts:
    In the vector-retrieval setting:

Query vector: analogous to the "test sample" in classification.
Retrieval results: analogous to the "positive" samples predicted by a classification model.
Ground truth: the correct vectors you expect to be similar to the query vector (i.e. the "positives").
Relevant vectors: the true matches of the query vector, which serve as the ground truth.
Non-relevant vectors: vectors unrelated to the query vector.

  2. Definitions of precision and recall
    Precision = (relevant vectors among the results) / (total vectors retrieved); Recall = (relevant vectors among the results) / (total relevant vectors in the ground truth).

  3. How to compute precision and recall

    1. Define the query and the ground truth:
      You need a query vector, and you need to know its set of truly relevant vectors (the ground truth).
    2. Run the retrieval:
      Use IVF (or any other vector-retrieval method) to look up the k nearest neighbors.
    3. Judge the correctness of the retrieval results:
      Among the retrieved vectors, those that belong to the query's ground truth are the correct results.
      Record how many relevant vectors were retrieved and how many relevant vectors exist in total.
    4. Compute precision and recall:
      Plug these counts into the formulas above.
  4. A worked example
    Suppose you have the following situation (a tiny runnable check follows this example):

  • Ground truth: the truly relevant vectors for the query are [v1, v2, v3, v4, v5], i.e. 5 relevant vectors in total.
  • Retrieval results: your IVF retrieval system returns the 10 most similar vectors: [v1, v6, v2, v7, v8, v3, v9, v10, v5, v11].
    Precision
  • The relevant vectors among the results are [v1, v2, v3, v5] (4 in total).
  • The total number of retrieved vectors is 10.
    Therefore:
    Precision = 4/10 = 0.4
    Recall
  • The relevant vectors among the results are [v1, v2, v3, v5] (4 in total).
  • The total number of truly relevant vectors is 5.
    Therefore:
    Recall = 4/5 = 0.8
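
A tiny, dependency-free sketch that reproduces the numbers above (the v1…v11 names are just the placeholders from the example):

def precision_recall(retrieved, relevant):
    # vectors that were both retrieved and truly relevant
    hits = [v for v in retrieved if v in set(relevant)]
    precision = len(hits) / len(retrieved) if retrieved else 0
    recall = len(hits) / len(relevant) if relevant else 0
    return precision, recall


retrieved = ['v1', 'v6', 'v2', 'v7', 'v8', 'v3', 'v9', 'v10', 'v5', 'v11']
relevant = ['v1', 'v2', 'v3', 'v4', 'v5']
print(precision_recall(retrieved, relevant))  # -> (0.4, 0.8)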

Putting It into Practice

import os
import numpy as np
import requests

# Define the directory and API URL
base_dir = r'D:\workspace\datasets\geneFruit\val'
api_url = 'http://localhost:3002/feature'


def calculate_precision_recall(labels, true_label, true_relevant_count):
    """
    Compute precision and recall for a single image
    """
    relevant_found = np.sum(labels == true_label)

    precision = relevant_found / len(labels) if len(labels) > 0 else 0
    recall = relevant_found / true_relevant_count if true_relevant_count > 0 else 0

    return precision, recall


def process_image(image_path, store_code):
    """
    Process a single image: send the request and return the predicted labels
    """
    with open(image_path, 'rb') as image_file:
        payload = {
            'storeCode': store_code,
            'needList': True,
        }
        response = requests.post(api_url + '/predict',
                                 files={'file': (os.path.basename(image_path), image_file, 'image/jpeg')},
                                 data=payload)
        if response.status_code == 201:
            data = response.json()
            feature_list = data.get('featureList', [])
            labels = np.array([item['label'] for item in feature_list])
            return labels
        else:
            print(f'Failed to upload {os.path.basename(image_path)}, status code: {response.status_code}')
            return None


def process_label(store_code, label):
    """
    Process one label: compute precision and recall for every image under that label
    """
    label_path = os.path.join(base_dir, label)
    if not os.path.isdir(label_path):
        return None

    precision_sum, recall_sum, count = 0, 0, 0

    # Get the ground-truth count of relevant vectors for this label
    res = requests.get(api_url + f'/getCountByLabel?storeCode={store_code}&label={label}')
    true_relevant_count = res.json()

    for image_name in os.listdir(label_path):  # iterate over the images under this label
        image_path = os.path.join(label_path, image_name)
        if image_name.lower().endswith(('png', 'jpg', 'jpeg', 'bmp', 'gif')):
            labels = process_image(image_path, store_code)
            if labels is not None:
                precision, recall = calculate_precision_recall(labels, label,
                                                               true_relevant_count)
                precision_sum += precision
                recall_sum += recall
                count += 1

    if count > 0:
        precision_avg = (precision_sum / count) * 100
        recall_avg = (recall_sum / count) * 100
    else:
        precision_avg, recall_avg = 0, 0

    return {'label': label, 'precision': f"{precision_avg:.2f}", 'recall': f"{recall_avg:.2f}"}


def val_data(store_code, target_label=None):
    all_results = []

    labels = [target_label] if target_label else os.listdir(base_dir)

    for label in labels:  # iterate over each label
        result = process_label(store_code, label)
        if result:
            all_results.append(result)
            print(f'{label} precision: {result["precision"]}%, recall: {result["recall"]}%')

    result_str = '\n'.join(
        f'{item["label"]}, precision: {item["precision"]}%, recall: {item["recall"]}%' for item in all_results
    )
    print(result_str)
    return result_str
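
Invoked the same way as the trainer; the store code and label below are placeholders:

if __name__ == '__main__':
    val_data('1001')  # evaluate every label in the validation set
    # val_data('1001', target_label='apple')  # or evaluate a single label only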

Next Steps

  1. Optimize the sync speed: it currently takes about 30 seconds, which is nowhere near production-grade.
  2. Implement add / delete / update / query for the IVF feature-vector index (a rough sketch of the faiss primitives involved follows).
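
Purely as a starting point for item 2, a minimal sketch of the faiss primitives that incremental updates would build on (add_with_ids / remove_ids on an IVF index). This is not the actual implementation; the dimension, nlist and ids are made up for illustration:

import faiss
import numpy as np

dim, nlist = 128, 10
quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFFlat(quantizer, dim, nlist)

# train once on an initial batch, then add vectors tagged with their MySQL ids
initial = np.random.rand(1000, dim).astype('float32')
index.train(initial)
index.add_with_ids(initial, np.arange(1000, dtype='int64'))

# add a newly learned feature without rebuilding the whole index
new_vec = np.random.rand(1, dim).astype('float32')
index.add_with_ids(new_vec, np.array([1000], dtype='int64'))

# remove a deleted feature by its id
index.remove_ids(np.array([42], dtype='int64'))

# search now returns the MySQL ids directly, which would also remove the need
# for the position-to-id mapping currently kept in Redis
distances, ids = index.search(new_vec, 5)
print(ids)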