SAHI slicing-aided training and inference

The goal of this post is to slice a dataset annotated in YOLOv5 format and save the resulting image tiles and label files.

The implementation proceeds in three steps (the full script follows the list):

  1. Convert the YOLOv5 labels to COCO format;
  2. Slice the images together with the COCO labels;
  3. Convert the sliced COCO labels back to YOLOv5 format.
# 1. convert the yolov5 labels to coco format
# 2. slice the images and the coco labels
# 3. convert the sliced coco labels back to yolov5 format

import os
import numpy as np
import cv2
from sahi.utils.coco import Coco, CocoCategory, CocoImage, CocoAnnotation
from sahi.slicing import slice_coco
from sahi.utils.file import save_json

def convert2coco(img_path, h, w, yololabel):
    # build a single-image COCO dataset from labels already converted
    # to absolute top-left (x, y, w, h) pixel coordinates
    coco = Coco()
    maps = {
        0: 'person',
        1: 'soccer'
    }
    coco.add_category(CocoCategory(id=0, name='person'))  # two categories
    coco.add_category(CocoCategory(id=1, name='soccer'))

    coco_image = CocoImage(file_name=img_path, height=h, width=w)

    for label in yololabel:
        coco_image.add_annotation(
            CocoAnnotation(
                bbox=[label[1], label[2], label[3], label[4]],
                category_id=int(label[0]),
                category_name=maps[int(label[0])]
            )
        )

    coco.add_image(coco_image)
    coco_json = coco.json
    save_json(coco_json, "coco_dataset.json")
    return coco_json

def convert2xywh(l, h, w):  # convert (class, cx, cy, w, h) to top-left (class, x, y, w, h)
    if l.ndim == 1:  # np.loadtxt yields a 1-D array for single-line label files
        l = l.reshape(1, -1)
    l = l.copy()  # work on a copy; don't mutate the caller's array
    new_l = np.zeros_like(l)
    l[:, 1] = l[:, 1] * w
    l[:, 3] = l[:, 3] * w
    l[:, 2] = l[:, 2] * h
    l[:, 4] = l[:, 4] * h

    new_l[:, 0] = l[:, 0]
    new_l[:, 1] = l[:, 1] - l[:, 3] / 2
    new_l[:, 2] = l[:, 2] - l[:, 4] / 2
    new_l[:, 3] = l[:, 3]
    new_l[:, 4] = l[:, 4]

    return new_l
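
# quick sanity check for convert2xywh with a hypothetical label row: on a
# 1920x1080 frame, (class=0, cx=0.5, cy=0.5, w=0.1, h=0.2) becomes a 192x216
# box with its top-left corner at (864, 432):
#   convert2xywh(np.array([[0, 0.5, 0.5, 0.1, 0.2]]), 1080, 1920)
#   -> array([[  0., 864., 432., 192., 216.]])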

def slice_img(save_img_dir):
    # slice the single-image COCO dataset into 640x640 tiles with 20% overlap
    coco_dict, coco_path = slice_coco(
        coco_annotation_file_path="coco_dataset.json",
        image_dir='',  # file_name in the json already contains the full image path
        slice_height=640,
        slice_width=640,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        output_dir=save_img_dir,
        output_coco_annotation_file_name='sliced',
        min_area_ratio=0.2,  # drop boxes that keep less than 20% of their area in a slice
        ignore_negative_samples=True  # skip slices that contain no annotations
    )
    return

def convert2yolov5(coco_dir, save_img_dir, save_label_dir):
    # load the sliced coco annotations and export them back to yolov5 format;
    # no train_split_rate is given, so everything is exported as train
    coco = Coco.from_coco_dict_or_path(coco_dir, save_img_dir)
    coco.export_as_yolov5(
        output_dir=save_label_dir,
        disable_symlink=True
    )
    return

if __name__ == '__main__':
    file = 'SNMOT-061'
    img_dir = f'datasets/soccernet/tracking/images/train/{file}/img1/'
    anno_dir = f'datasets/soccernet/tracking/labels/train/{file}/img1/'
    save_img_dir = f'datasets/sliced_soccernet/images/train/{file}/'  # sliced images are saved here
    save_label_dir = f'datasets/sliced_soccernet/labels/train/{file}/'
    os.makedirs(save_img_dir, exist_ok=True)
    os.makedirs(save_label_dir, exist_ok=True)
    labels = os.listdir(anno_dir)
    for label in labels:
        if 'old' not in label:
            try:
                os.remove('coco_dataset.json')  # remove intermediate files from the previous image
                os.remove(save_img_dir + 'sliced_coco.json')
            except FileNotFoundError:
                pass
            l = np.loadtxt(anno_dir + label, delimiter=' ')  # class cx cy w h
            img_path = img_dir + label.replace('txt', 'jpg')
            img = cv2.imread(img_path)
            h, w, _ = img.shape
            new_l = convert2xywh(l, h, w)
            coco_json = convert2coco(img_path, h, w, new_l)
            slice_img(save_img_dir)  # slice the image and save the tiles
            convert2yolov5(save_img_dir + 'sliced_coco.json', save_img_dir, save_label_dir)  # convert the sliced coco labels back to yolo format and save

            # for ll in new_l:  # verify the conversion by drawing the boxes
            #     if int(ll[0]) == 0:
            #         cv2.rectangle(img, (int(ll[1]), int(ll[2])), (int(ll[1] + ll[3]), int(ll[2] + ll[4])), (255, 0, 255), 2)
            #     else:
            #         cv2.rectangle(img, (int(ll[1]), int(ll[2])), (int(ll[1] + ll[3]), int(ll[2] + ll[4])), (255, 0, 2), 2)
            # cv2.imwrite('./test.jpg', img)
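A note on convert2yolov5 above: since no split is configured, export_as_yolov5 puts everything into train. If a train/val split is wanted, the method also accepts a train_split_rate argument. A sketch, reusing the script's variables (the 0.85 below is an illustrative value, not from the original post):

coco = Coco.from_coco_dict_or_path(save_img_dir + 'sliced_coco.json', save_img_dir)
coco.export_as_yolov5(
    output_dir=save_label_dir,
    train_split_rate=0.85,  # illustrative: 85% train / 15% val
    disable_symlink=True
)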

On top of the above, the sahi source code needs a small modification: by default it saves the images, so comment that part out:
[screenshots of the source-code edit]

Final result:
[result screenshots]
Once the conversion is done, the following code visualizes the sliced labels:

import os
import cv2


# train_lists = os.listdir('datasets/soccernet/tracking/images/train')

# for tl in train_lists:

root = 'datasets/sliced_soccernet/images/train/' + 'SNMOT-061/'
# root2 = 'datasets/soccernet/tracking/images2/train/'+'Vision_State.v4i.yolov8'+'/images/'
train_list = os.listdir(root)

for img in train_list:
    if not img.endswith('jpg'):
        continue
    res = []
    ball_bbox = []
    data = cv2.imread(root + img)
    ih, iw, c = data.shape
    try:
        anno = open(root.replace('images', 'labels') + img[:-4] + '.txt').read().splitlines()
    except FileNotFoundError:  # some tiles may have no label file
        continue
    for an in anno:
        a = an.split(' ')
        cls, x, y, w, h = int(a[0]), float(a[1]), float(a[2]), float(a[3]), float(a[4])
        # labels are normalized (cx, cy, w, h); convert to pixel coordinates
        x = int(iw * x)
        y = int(ih * y)
        w = int(w * iw)
        h = int(h * ih)
        if cls == 1:  # soccer
            cv2.rectangle(data, (x - w // 2, y - h // 2), (x + w // 2, y + h // 2), (255, 255, 0), 2)
            ball_bbox.append([x - w // 2, y - h // 2, w, h])
        else:  # person
            res.append([x - w // 2, y - h // 2, x + w // 2, y + h // 2])
            cv2.rectangle(data, (x - w // 2, y - h // 2), (x + w // 2, y + h // 2), (0, 0, 255), 2)
    # if len(ball_bbox) > 0:
    #     crop_img = data[max(0, ball_bbox[0][1] - 50):ball_bbox[0][1] + ball_bbox[0][3] + 50, max(0, ball_bbox[0][0] - 50):ball_bbox[0][0] + ball_bbox[0][2] + 50]

    # save the annotated tile
    save_dir = 'outimg/' + root
    os.makedirs(save_dir, exist_ok=True)
    cv2.imwrite(os.path.join(save_dir, img), data)

Sliced inference after training

from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.utils.file import download_from_url
from sahi.predict import get_prediction, get_sliced_prediction, predict


yolov8_model_path = 'runs/detect/train13-sahi-n/weights/last.pt'


# single-image sliced inference
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',
    model_path=yolov8_model_path,
    confidence_threshold=0.6,
    device="cuda:7", # or 'cuda:0'
)

result = get_sliced_prediction(
    "20231023-134213.jpeg",
    detection_model,
    slice_height=640,
    slice_width=640,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)
result.export_visuals(export_dir="sahi/", text_size=1, rect_th=1)
object_prediction_list = result.object_prediction_list
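
The predictions can also be inspected directly; a minimal sketch using sahi's to_coco_annotations() helper (the dict keys shown are what recent sahi versions return; treat them as an assumption):

# each entry is a plain dict; bbox is [x, y, w, h] in pixels of the full image
for ann in result.to_coco_annotations():
    print(ann["category_name"], ann["score"], ann["bbox"])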


# batch sliced inference over an image directory
source_image_dir = "sahi_in/1"
result = predict(
    model_type='yolov8',
    model_path=yolov8_model_path,
    model_device="cuda:7",
    model_confidence_threshold=0.6,
    source=source_image_dir,
    image_size=(640, 640),
    slice_height=960,
    slice_width=1280,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    visual_text_size=1,
    return_dict=True
)

Also, the value returned by batch inference does not include the prediction results by default; the source code needs to be modified:
[screenshots of the source-code edit]
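
If patching the library is undesirable, a workaround is to drive the loop yourself with get_sliced_prediction, which does return its predictions. A sketch, reusing the detection_model built above (the file-extension filter is illustrative):

import os

for name in sorted(os.listdir(source_image_dir)):
    if not name.lower().endswith(('.jpg', '.jpeg', '.png')):  # illustrative filter
        continue
    r = get_sliced_prediction(
        os.path.join(source_image_dir, name),
        detection_model,
        slice_height=960,
        slice_width=1280,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
    )
    print(name, len(r.object_prediction_list))  # predictions are available per image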

