使用级联检测策略准确识别猫和猫眼

在动物识别的图像处理中,尤其是在图片中同时出现多种动物的情况下,准确地识别特定动物的特定部位(如猫的眼睛)并不是一件容易的事。常规的对象识别模型在面对复杂的场景时很容易产生误报,例如可能将狗的眼睛错误识别为猫的眼睛。为了提高准确性并减少这类误报,我们可以采用一种称为“级联检测”的方法。

为什么需要级联检测?

级联检测通过分步骤使用两个或更多专注于不同任务的模型来提高整体的识别精度。简单来说,就像是先用一个放大镜粗略找到猫在哪里,再用另一个放大镜仔细观察猫的眼睛在哪里。这种方法的优点在于:

  • 减少误报: 第一个模型确保我们只在猫的区域寻找猫眼,避免了在狗或其他动物上错误标记猫眼的情况。
  • 提高精度: 专门用于猫眼的模型可以更精确地定位和识别猫眼,因为它只关注已经被第一个模型识别为猫的区域。

示例:用YOLOv5识别猫和猫眼

假设我们有两个预训练的YOLOv5模型:一个专门用来识别猫,另一个专门用来识别猫眼。下面是如何实现这一级联识别的示例代码:

import os
import torch
import numpy as np
import cv2
import copy
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh
from utils.augmentations import letterbox
from models.experimental import attempt_load
from utils.torch_utils import select_device

class CatEyeDetector:
    def __init__(self, model_path_cat, model_path_eye):
        """初始化猫和猫眼检测器
        
        Args:
            model_path_cat: 猫检测模型路径
            model_path_eye: 猫眼检测模型路径
        """
        # 通用设置
        self.device = select_device('0')  # 使用GPU 0,如需CPU使用''
        self.conf_thres = 0.4  # 置信度阈值
        self.iou_thres = 0.45  # NMS IoU阈值
        self.classes = None  # 不限制类别
        self.agnostic_nms = False  # 不使用类别无关NMS
        self.max_det = 1000  # 最大检测数量
        self.augment = False  # 不使用增强推理
        
        # 猫模型相关设置
        imgsz_cat = 640  # 输入尺寸
        self.model_cat = attempt_load(model_path_cat, self.device)
        self.stride_cat = int(self.model_cat.stride.max())
        self.imgsz_cat = check_img_size(imgsz_cat, s=self.stride_cat)
        self.names_cat = self.model_cat.names  # 获取类别名称
        if isinstance(self.names_cat, dict):
            self.names_cat = list(self.names_cat.values())
            
        # 猫眼模型相关设置
        imgsz_eye = 640
        self.model_eye = attempt_load(model_path_eye, self.device)
        self.stride_eye = int(self.model_eye.stride.max())
        self.imgsz_eye = check_img_size(imgsz_eye, s=self.stride_eye)
        self.names_eye = self.model_eye.names
        if isinstance(self.names_eye, dict):
            self.names_eye = list(self.names_eye.values())
            
        # 合并类别列表
        self.class_list = self.names_eye + self.names_cat
        
        # 结果后处理参数
        self.final_nms_iou_thres = 0.3  # 最终NMS的IoU阈值

    def _preprocess(self, img0, imgsz, stride):
        """预处理图像
        
        Args:
            img0: 原始图像
            imgsz: 调整大小的目标尺寸
            stride: 模型的stride
            
        Returns:
            预处理后的图像张量
        """
        img = letterbox(img0, imgsz, stride=stride, auto=True)[0]  # 调整大小并填充
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # 确保内存连续
        img = torch.from_numpy(img).to(self.device)
        img = img.float()  # uint8 to fp16/32
        img /= 255.0  # 0-255 to 0.0-1.0
        if len(img.shape) == 3:
            img = img[None]  # 扩展batch维度
        return img

    @torch.no_grad()
    def _detect(self, img, model):
        """使用模型进行检测
        
        Args:
            img: 预处理后的图像张量
            model: 检测模型
            
        Returns:
            检测结果
        """
        pred = model(img, augment=self.augment)[0]
        return pred

    def _process_cat_detection(self, pred_cat, img_cat, img0_cat):
        """处理猫检测结果
        
        Args:
            pred_cat: 猫检测原始预测结果
            img_cat: 输入网络的调整大小后的猫图像
            img0_cat: 原始图像
            
        Returns:
            处理后的猫检测结果、猫区域列表
        """
        # NMS处理
        pred_cat_nms = non_max_suppression(pred_cat, self.conf_thres, self.iou_thres, 
                                          self.classes, self.agnostic_nms, max_det=self.max_det)
        
        det_cat = []
        cat_boxes = []
        
        for i, det in enumerate(pred_cat_nms):  # 每张图像的检测结果
            if len(det):
                # 将检测坐标从调整大小后的图像映射回原始图像
                det[:, :4] = scale_coords(img_cat.shape[2:], det[:, :4], img0_cat.shape).round()
                
                # 只保留猫类别(假设猫的类别索引为0)
                cats = det[det[:, 5] == 0]
                
                for *xyxy, conf, cls in cats:
                    x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                    cat_boxes.append([x1, y1, x2, y2])
                
                det_cat = cats
                
        return det_cat, cat_boxes

    def _crop_cat_areas(self, img0_cat, cat_boxes):
        """裁剪猫区域用于猫眼检测
        
        Args:
            img0_cat: 原始图像
            cat_boxes: 猫的边界框列表
            
        Returns:
            裁剪区域列表和对应的边界框信息
        """
        img_crops = []
        boxes_info = []
        
        for box in cat_boxes:
            x1, y1, x2, y2 = box
            # 直接裁剪,不进行外扩
            img_crop = img0_cat[y1:y2, x1:x2]
            img_crops.append(img_crop)
            boxes_info.append(box)
            
        return img_crops, boxes_info

    def _process_eye_detection(self, pred_eye, box_info, img_eye, img_crop):
        """处理猫眼检测结果并映射坐标
        
        Args:
            pred_eye: 猫眼检测原始预测结果
            box_info: 猫区域边界框信息
            img_eye: 输入网络的猫眼图像
            img_crop: 裁剪的猫区域图像
            
        Returns:
            处理后的猫眼检测结果
        """
        # NMS处理
        pred_eye_nms = non_max_suppression(pred_eye, self.conf_thres, self.iou_thres, 
                                          self.classes, self.agnostic_nms, max_det=self.max_det)
        
        det_eye = []
        
        for i, det in enumerate(pred_eye_nms):  # 每张图像的检测结果
            if len(det):
                # 将检测坐标从调整大小后的图像映射回裁剪图像
                det[:, :4] = scale_coords(img_eye.shape[2:], det[:, :4], img_crop.shape).round()
                
                # 坐标映射回原图(核心步骤)
                x1, y1, _, _ = box_info
                det[:, 0] += x1  # x坐标加上裁剪区域左上角x坐标
                det[:, 2] += x1
                det[:, 1] += y1  # y坐标加上裁剪区域左上角y坐标
                det[:, 3] += y1
                
                det_eye = det
                
        return det_eye

    def _final_nms(self, det_all):
        """对所有检测结果进行最终NMS
        
        Args:
            det_all: 所有检测结果
            
        Returns:
            NMS后的最终结果
        """
        # 筛选有效检测结果
        valid_dets = det_all[det_all[:, 2] > 0]
        
        # 按类别分组进行NMS
        cat_dets = valid_dets[valid_dets[:, 5] == 0]  # 猫检测结果(类别0)
        eye_dets = valid_dets[valid_dets[:, 5] == 1]  # 猫眼检测结果(类别1)
        
        # 对猫眼检测结果进行NMS
        if len(eye_dets) > 1:
            boxes, scores = eye_dets[:, :4], eye_dets[:, 4]
            i = torchvision.ops.nms(boxes, scores, self.final_nms_iou_thres)
            eye_dets = eye_dets[i]
            
        # 合并结果
        final_dets = torch.cat([cat_dets, eye_dets], dim=0)
        return final_dets

    def detect(self, img0):
        """执行级联检测
        
        Args:
            img0: 输入图像
            
        Returns:
            最终检测结果,图像
        """
        # 初始化结果tensor
        det_all = torch.zeros((self.max_det, 6)).to(self.device)
        det_count = 0  # 结果计数
        
        # 第一阶段:猫检测
        img_cat = self._preprocess(img0, self.imgsz_cat, self.stride_cat)
        pred_cat = self._detect(img_cat, self.model_cat)
        det_cat, cat_boxes = self._process_cat_detection(pred_cat, img_cat, img0)
        
        # 将猫检测结果添加到总结果中
        if len(det_cat) > 0:
            num_cats = det_cat.shape[0]
            det_all[:num_cats] = det_cat
            det_count += num_cats
        
        # 第二阶段:猫眼检测(只在检测到猫的区域)
        img_crops, boxes_info = self._crop_cat_areas(img0, cat_boxes)
        
        for img_crop, box_info in zip(img_crops, boxes_info):
            if img_crop.size == 0 or img_crop is None:
                continue
                
            # 预处理裁剪图像
            img_eye = self._preprocess(img_crop, self.imgsz_eye, self.stride_eye)
            
            # 检测猫眼
            pred_eye = self._detect(img_eye, self.model_eye)
            
            # 处理猫眼检测结果
            det_eye = self._process_eye_detection(pred_eye, box_info, img_eye, img_crop)
            
            # 添加到总结果
            if len(det_eye) > 0:
                num_eyes = det_eye.shape[0]
                det_all[det_count:det_count + num_eyes] = det_eye
                det_count += num_eyes
        
        # 最终NMS
        det_all = self._final_nms(det_all)
        
        return det_all, img0

    def show_result(self, img0):
        """格式化输出检测结果
        
        Args:
            img0: 输入图像
            
        Returns:
            检测结果字典列表
        """
        det_all, img0 = self.detect(img0)
        detect_res = []
        
        for *xyxy, conf, cls in reversed(det_all):
            # 转换为中心点坐标和宽高格式
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4))).view(-1).tolist()
            
            x_mid, y_mid, width, height = int(xywh[0]), int(xywh[1]), int(xywh[2]), int(xywh[3])
            x_min = x_mid - width // 2
            y_min = y_mid - height // 2
            
            # 获取类别标签
            label = "cat" if int(cls) == 0 else "cat_eye"
            
            # 生成结果字典
            result = {
                'score': round(float(conf), 2),
                'tag': label,
                'warning': 1,
                'frame': {
                    'x': x_min,
                    'y': y_min,
                    'width': width,
                    'height': height,
                }
            }
            
            detect_res.append(result)
            
        return detect_res

    def visualize_results(self, img0, results, save_path=None):
        """可视化检测结果
        
        Args:
            img0: 原始图像
            results: 检测结果列表
            save_path: 可选,保存路径
        """
        # 创建图像副本
        img_vis = img0.copy()
        
        # 绘制检测框
        for r in results:
            tag = r['tag']
            score = r['score']
            x = r['frame']['x']
            y = r['frame']['y']
            width = r['frame']['width']
            height = r['frame']['height']
            
            # 为不同类别使用不同颜色
            color = (0, 255, 0) if tag == 'cat' else (0, 0, 255)  # 猫使用绿色,猫眼使用红色
            
            # 绘制边界框
            cv2.rectangle(img_vis, (x, y), (x + width, y + height), color, 2)
            
            # 绘制标签
            label = f"{tag} {score:.2f}"
            t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
            cv2.rectangle(img_vis, (x, y - t_size[1] - 5), (x + t_size[0], y), color, -1)
            cv2.putText(img_vis, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        
        # 保存或显示图像
        if save_path:
            cv2.imwrite(save_path, img_vis)
        else:
            cv2.imshow('Result', img_vis)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        
        return img_vis


if __name__ == "__main__":
    # 模型路径
    model_path_cat = "weights/cat_detector.pt"
    model_path_eye = "weights/cat_eye_detector.pt"
    
    # 初始化检测器
    detector = CatEyeDetector(model_path_cat, model_path_eye)
    
    # 测试图像路径
    img_path = "test_images/cat.jpg"
    
    # 读取图像
    img0 = cv2.imread(img_path)
    
    # 执行检测
    results = detector.show_result(img0)
    print(results)
    
    # 可视化结果
    detector.visualize_results(img0, results, "result.jpg")

代码详解

我们的级联检测器实现了以下关键步骤:

  • 1. 初始化两个YOLOv5模型:
    一个用于检测猫的模型(第一级检测)
    一个用于检测猫眼的模型(第二级检测)
  • 2. 级联检测流程:
    首先使用第一个模型在整个图像中检测猫
    对每个检测到的猫区域进行裁剪
    在裁剪的区域内使用第二个模型检测猫眼
    将猫眼坐标从裁剪图像映射回原始图像
  • 3. 坐标映射机制:
    将在裁剪图像中检测到的猫眼坐标加上裁剪区域的左上角坐标,得到猫眼在原始图像中的位置
    这是级联检测中的关键步骤,确保了最终结果在原图中的准确定位
  • 4. 后处理:
    对所有检测结果(猫和猫眼)进行最终的NMS处理,消除冗余检测
    格式化输出结果,便于后续处理和可视化
  • 5. 可视化:
    提供可视化功能,用不同颜色标注猫和猫眼的位置
    在图像上显示类别标签和置信度

结论

通过使用级联YOLOv5模型,我们可以显著减少在复杂环境中的误报,同时提高猫眼检测的精确度。这种方法的优势在于它将复杂问题分解为两个更简单的子问题,先识别大目标(猫),再在大目标区域内精确定位小目标(猫眼)。

这种级联检测思路不仅适用于猫眼检测,也可以应用于其他类似的场景,例如人脸-人眼检测、车辆-车牌检测等。在工业应用中,这种方法也可以用于大金具-小金具的检测,显著提高检测精度和效率。

通过这种方法,我们可以有效地简化问题,逐步解决,最终获得更可靠的结果,大幅提升图像分析应用的准确性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值