基于YoloV5 multi task网络增加分类分支

本文介绍了如何在YoloV5中添加分类分支,以同时实现目标检测和图像分类任务。主要步骤包括修改数据标签,调整数据输入代码,网络结构的改动,模型输出与损失函数的计算,以及最终的可视化结果展示。这种方法允许在共享Backbone的情况下,提高模型的多功能性。
摘要由CSDN通过智能技术生成


前言

YoloV5做目标检测任务,但是当我们需要对图像中的目标进行检测,还需要根据图像整体对该图像做分类时,我们可以在Yolo-Detect检测头前加一个分类分支,这样可以在共享Backbone情况下实现做检测任务,又做分类任务,下面具体介绍增加分类分支的操作方法

一、数据准备

如下图,yolov5的标签 .txt 文件, 只需在第一行新增类别标签(0,1,2…),第二行开始是目标框bbox标签
请添加图片描述

二、数据输入处修改代码

对标签label增加分类分支
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

三、网络修改

在这里插入图片描述
在这里插入图片描述

四、模型输出的和标签接上去,计算loss

在这里插入图片描述
在这里插入图片描述

五、最终可视化结果

原版yolo输出3个检测头,增加后多输出一个分类分支
在这里插入图片描述

总结

简单记录一下,因为yolo版本也在不断更新,说一没有详细说明,简单对主要修改处标记了,如有类似修改需求。大致参照思路即可。不能全抄。

实现基于yolov5框架的人脸识别代码如下: ``` import cv2 import numpy as np import torch from models.experimental import attempt_load from utils.general import non_max_suppression, scale_coords from utils.torch_utils import select_device # 加载模型 model = attempt_load("yolov5s.pt", map_location='cpu') # 设置设备 device = select_device("cpu") # 设置阈值 confidence = 0.4 iou_thresh = 0.5 # 加载类别名称 classes = ["face"] # 开启摄像头 cap = cv2.VideoCapture(0) # 循环读取帧 while True: # 读取帧 ret, frame = cap.read() if not ret: break # 转化为RGB frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 处理大小 img = cv2.resize(frame, (640, 640)) # 转化为张量 img = torch.from_numpy(img).to(device) # 添加维度 img = img.permute(2, 0, 1).unsqueeze(0) # 模型推理 with torch.no_grad(): output = model(img.float(), augment=False)[0] # 处理坐标 boxes = output[..., :4] boxes = scale_coords(img.shape[2:], boxes, frame.shape).round() # 处理置信度和类别 confidence_scores = output[..., 4] class_indexes = output[..., 5].long() # 多类别处理 detections = [] for class_index in range(output.shape[2]-5): class_mask = class_indexes == class_index if not class_mask.any(): continue class_boxes = boxes[class_mask] class_confidence_scores = confidence_scores[class_mask] # 多类别非极大抑制 class_detections = non_max_suppression( torch.cat((class_boxes, class_confidence_scores.unsqueeze(1)), dim=1), conf_thres=confidence, iou_thres=iou_thresh, multi_label=False, classes=None, agnostic=True ) for detection in class_detections: detection = detection.cpu().numpy() detection = detection[0:4].astype(np.int) detections.append(detection) # 画人脸框 for box in detections: x_min, y_min, x_max, y_max = box cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2) # 显示结果 cv2.imshow("face detection", frame) # 退出 if cv2.waitKey(1) & 0xFF == ord('q'): break # 释放摄像头 cap.release() # 关闭窗口 cv2.destroyAllWindows() ``` 该代码首先加载yolov5模型,然后设置设备、阈值和类别名称。接下来开启摄像头,循环读取每一帧,并进行模型推理。模型输出包含置信度、类别和坐标信息,需要对其进行处理,得到每一个人脸框的坐标。最后,将每一个人脸框用矩形框住,并显示结果。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值