Notes on using YOLOv5 under ROS

Environment: ROS Melodic, Ubuntu 18.04

First, let's lay out the approach: we need to read images in ROS and then run YOLO detection on them. In ROS, images are read like this:

bridge = CvBridge()

def callback(data):
    global image
    # Convert the ROS Image message to an OpenCV BGR array
    image = bridge.imgmsg_to_cv2(data, desired_encoding="bgr8")

sub_image = rospy.Subscriber("/camera/rgb/image_raw", Image, callback)

Images are delivered through the callback, so image holds a single frame per call rather than a whole video stream. In the YOLO code, input is loaded like this:

dataset = LoadStreams(source, img_size=imgsz, stride=stride) if source.isnumeric() else LoadImages(source, img_size=imgsz, stride=stride)  # load the dataset

In this line, image input is handled by LoadImages. Looking at that class shows that it applies only the following processing to each image:

 img = letterbox(img0, self.img_size, stride=self.stride)[0]

 # Convert
 img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
 img = np.ascontiguousarray(img)
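The letterbox step resizes the frame to fit the target size while keeping its aspect ratio, then pads it. As a rough illustration of what size comes out, here is a plain-Python sketch of the shape arithmetic. It assumes letterbox's default behaviour of padding only up to the next multiple of the stride, and letterbox_shape is a hypothetical helper name, not a YOLOv5 function.

```python
def letterbox_shape(height, width, new_size=640, stride=32):
    """Return the (height, width) of the image after resize and minimal padding."""
    # Scale so that the longer side fits into new_size
    ratio = min(new_size / height, new_size / width)
    new_h, new_w = int(round(height * ratio)), int(round(width * ratio))
    # Pad each side only as far as the next multiple of stride (assumed default)
    pad_h = (new_size - new_h) % stride
    pad_w = (new_size - new_w) % stride
    return new_h + pad_h, new_w + pad_w


print(letterbox_shape(480, 640))   # (480, 640)
print(letterbox_shape(720, 1280))  # (384, 640)
```

Under that assumption the array fed to the model is not necessarily square, so the 3x416x416 in the code comment is best read as an example shape rather than a fixed one.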

That gives us the preprocessing we need, so the final callback looks like this:

def callback(data):
    global image
    image = bridge.imgmsg_to_cv2(data, desired_encoding="bgr8")  # make sure the encoding is correct
    img = letterbox(image, imgsz, stride=stride)[0]

    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)
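The two lines under # Convert can be checked on a tiny array. The NumPy-only sketch below uses a made-up 2x3 test image. It shows the channel order flipping from BGR to RGB, the layout changing from height x width x channels to channels x height x width, and the result staying a strided view until np.ascontiguousarray copies it.

```python
import numpy as np

# A made-up 2x3 BGR image that is pure blue: B=255, G=0, R=0
bgr = np.zeros((2, 3, 3), dtype=np.uint8)
bgr[:, :, 0] = 255

view = bgr[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, then HWC -> CHW
chw = np.ascontiguousarray(view)           # copy into a plain C-ordered array

print(view.shape)                  # (3, 2, 3)
print(view.flags['C_CONTIGUOUS'])  # False
print(chw.flags['C_CONTIGUOUS'])   # True
print(chw[0].max(), chw[2].min())  # 0 255: red plane is empty, blue plane is full
```

The copy matters because torch.from_numpy does not accept arrays with negative strides, and the ::-1 slice produces exactly that.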

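Image reading is now done; all that remains is to run YOLO detection on each image. We name the detection function detect(). A simplified version of the YOLO detection code: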

def detect(img):
    img = torch.from_numpy(img).to(device).float() / 255.0  # uint8 0-255 to float 0.0-1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only, no gradients needed
        pred = model(img, augment=augment)[0]  # model inference
    pred = non_max_suppression(pred, conf_thres, iou_thres, agnostic=agnostic_nms)  # non-maximum suppression

    for i, det in enumerate(pred):  # iterate over the detections for each image
        if len(det):
            # Rescale boxes from the letterboxed size back to the original frame (the global image)
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], image.shape).round()
            for *xyxy, conf, cls in reversed(det):
                if conf > 0.80:  # extra confidence threshold, stricter than conf_thres
                    label = f'{model.names[int(cls)]} {conf:.2f}'  # label text
                    plot_one_box(xyxy, image, label=label, color=colors[int(cls)], line_thickness=3)
                    result = model.names[int(cls)]
                    print(result)  # print the detection result
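The scale_coords call is what lets the boxes be drawn on the original frame: the model predicts coordinates in the letterboxed image, so the padding has to be subtracted and the resize undone. The sketch below works through that arithmetic in plain Python for a single box. It assumes the padding was split evenly between both sides, and unletterbox_box is a hypothetical name used only for illustration.

```python
def unletterbox_box(box, letterboxed_hw, original_hw):
    """Map an (x1, y1, x2, y2) box from the letterboxed image to the original image."""
    lh, lw = letterboxed_hw
    oh, ow = original_hw
    gain = min(lh / oh, lw / ow)   # resize factor that was applied
    pad_x = (lw - ow * gain) / 2   # padding added on the left
    pad_y = (lh - oh * gain) / 2   # padding added on the top
    x1, y1, x2, y2 = box
    xs = [(x - pad_x) / gain for x in (x1, x2)]
    ys = [(y - pad_y) / gain for y in (y1, y2)]
    # Clip to the bounds of the original image
    xs = [min(max(x, 0), ow) for x in xs]
    ys = [min(max(y, 0), oh) for y in ys]
    return (xs[0], ys[0], xs[1], ys[1])


# A 1280x720 frame letterboxed to 640x384: scaled by 0.5, padded 12 px top and bottom
restored = unletterbox_box((100, 112, 200, 212), (384, 640), (720, 1280))
print(restored)  # (200.0, 200.0, 400.0, 400.0)
```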

Finally, we add some of YOLO's default setup:

weights = '/home/weights/last.pt'  # path to the model weights
imgsz = 640  # inference image size
conf_thres = 0.25  # object confidence threshold
iou_thres = 0.45  # IoU threshold for NMS
device = ''  # empty string lets select_device choose
augment = False
agnostic_nms = False

set_logging()
device = select_device(device)
model = attempt_load(weights, map_location=device)  # load the model
stride = int(model.stride.max())
imgsz = check_img_size(imgsz, s=stride)
names = model.module.names if hasattr(model, 'module') else model.names  # class names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
if device.type != 'cpu':
    model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # warm-up run
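check_img_size is there because the network downsamples by the stride, so the input side length should be a multiple of it. A rough sketch of that rounding follows, assuming the helper rounds up to the next multiple; round_up_to_stride is a hypothetical name, not the YOLOv5 function.

```python
import math


def round_up_to_stride(size, stride=32):
    """Return the smallest multiple of stride that is greater than or equal to size."""
    return math.ceil(size / stride) * stride


print(round_up_to_stride(640))  # 640
print(round_up_to_stride(500))  # 512
```

With imgsz = 640 and a maximum stride of 32 the value is already valid, so the call leaves it unchanged.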

The finished script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import rospy
import numpy as np
import cv2
import torch
from numpy import random
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import set_logging, check_img_size, non_max_suppression, scale_coords
from utils.plots import plot_one_box
from utils.torch_utils import select_device


weights = '/home/weights/last.pt'  # path to the model weights
imgsz = 640  # inference image size
conf_thres = 0.25  # object confidence threshold
iou_thres = 0.45  # IoU threshold for NMS
device = ''  # empty string lets select_device choose
augment = False
agnostic_nms = False

set_logging()
device = select_device(device)
model = attempt_load(weights, map_location=device)  # load the model
stride = int(model.stride.max())
imgsz = check_img_size(imgsz, s=stride)
names = model.module.names if hasattr(model, 'module') else model.names  # class names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
if device.type != 'cpu':
    model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # warm-up run

def Detect_init():
    global bridge, sub_image
    bridge = CvBridge()
    sub_image = rospy.Subscriber("/camera/rgb/image_raw", Image, callback)

def callback(data):
    global image
    image = bridge.imgmsg_to_cv2(data, desired_encoding="bgr8")  # make sure the encoding is correct
    img = letterbox(image, imgsz, stride=stride)[0]

    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)

    detect(img)

def detect(img):
    img = torch.from_numpy(img).to(device).float() / 255.0  # uint8 0-255 to float 0.0-1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only, no gradients needed
        pred = model(img, augment=augment)[0]  # model inference
    pred = non_max_suppression(pred, conf_thres, iou_thres, agnostic=agnostic_nms)  # non-maximum suppression

    for i, det in enumerate(pred):  # iterate over the detections for each image
        if len(det):
            # Rescale boxes from the letterboxed size back to the original frame (the global image)
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], image.shape).round()
            for *xyxy, conf, cls in reversed(det):
                if conf > 0.80:  # extra confidence threshold, stricter than conf_thres
                    label = f'{model.names[int(cls)]} {conf:.2f}'  # label text
                    plot_one_box(xyxy, image, label=label, color=colors[int(cls)], line_thickness=3)
                    result = model.names[int(cls)]
                    print(result)  # print the detection result

if __name__ == '__main__':
    try:
        # init ROS node
        rospy.init_node('detect')
        rospy.loginfo("Starting Detect node")
        Detect_init()
        rospy.spin()
    except KeyboardInterrupt:
        print ("Shutting down node.")
        cv2.destroyAllWindows()

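The two main sticking points were: (1) understanding that the image in a ROS callback is a single frame, not a video stream; (2) hooking that image up to YOLO detection: because we work with single frames directly, YOLO's LoadImages loader is awkward to use as-is.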
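Point (1) has a practical consequence: every message triggers its own callback, so if detection is slower than the camera, frames can pile up. One common way to cope is to have the callback only store the newest frame and let a separate loop pick it up. The sketch below shows that idea in plain Python, without ROS. LatestFrame and its methods are hypothetical names, not part of rospy or YOLOv5.

```python
import threading


class LatestFrame:
    """Hold only the most recent frame; older unread frames are overwritten."""

    def __init__(self):
        self._lock = threading.Lock()
        self._frame = None

    def put(self, frame):
        # Meant to be called from the subscriber callback: store and return quickly
        with self._lock:
            self._frame = frame

    def take(self):
        # Meant to be called from the processing loop: hand out the newest frame once
        with self._lock:
            frame, self._frame = self._frame, None
            return frame


holder = LatestFrame()
for name in ("frame-1", "frame-2", "frame-3"):  # three callbacks arrive back to back
    holder.put(name)

first = holder.take()   # only the newest frame is left
second = holder.take()  # nothing new has arrived since
print(first, second)    # frame-3 None
```

In the node above, the callback would call put with the converted frame, and the preprocessing and detect() would run on whatever take returns.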
