mmdetection根据结果crop图片并保存

mmdeteciton保存检测结果

初始化模型

import mmcv
import os
import numpy as np
from mmdet.apis import init_detector, inference_detector
from PIL import Image
import matplotlib.pyplot as plt
import cv2

input_dir = ""
out_dir = ""
config_file = ""
checkpoint_file = ""
device = 'cuda:0'

model = init_detector(config_file, checkpoint_file, device=device)
CLASSES = model.CLASSES

if not os.path.exists(out_dir):
    os.mkdir(out_dir)

推理以及结果保存

def dete_result(file):
    img = mmcv.imread(os.path.join(input_dir, file))
    result = inference_detector(model, img)
    return result


def get_expend_box(box,
                   expend_ratio=0.25):
    """
    根据检测结果,按照0.25的比率放大检测框
    Args:
        box:
        expend_ratio:

    Returns:放大的检测框

    """
    height, width = box[3] - box[1], box[2] - box[0]
    box[0] = box[0] - expend_ratio*width
    box[2] = box[2] + expend_ratio*width
    box[1] = box[1] - expend_ratio*height
    box[3] = box[3] + expend_ratio*height

    return box


def get_crop_result(result=None,
                    save_calsses=None,
                    score_thr=0.1,):
    crop_list = []
    bbox_result = result
    bboxes = np.vstack(bbox_result)
    labels = [
                np.full(bbox.shape[0], i, dtype=np.int32)
                for i, bbox in enumerate(bbox_result)
            ]
    labels = np.concatenate(labels)
    if score_thr > 0:
        scores = bboxes[:, -1]
        inds = scores > score_thr
        bboxes = bboxes[inds, :]
        labels = labels[inds]

    positions = bboxes[:, :4].astype(np.int32)
    for i, (pos, label) in enumerate(zip(positions, labels)):
        label_text = CLASSES[
                label] if CLASSES is not None else f'class {label}'
        if label_text == save_calsses:
            crop_list.append(pos)

    return crop_list

主函数

def main():
    files = os.listdir(input_dir)
    if len(files) != 0:
        for file in files:
            img = Image.open(os.path.join(input_dir, file))
            name = os.path.splitext(file)[0]
            print('detecting: ' + name)
            result = dete_result(file)
            crop_list = get_crop_result(result=result, score_thr=0.7, save_calss='Car')
            for i, bbox in enumerate(crop_list):
                bbox = get_expend_box(box=bbox)
                img_crop = img.crop(bbox)
                path = os.path.join(out_dir, name+str(i) + ".jpg")
                img_crop.save(path)
                # plt.imshow(img_crop)
                # plt.show()
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值