基于 PaddleDetection 的目标检测 labelme 标注自动生成

在百度 PaddleDetection 项目的基础上,实现目标检测 labelme 标注文件的自动生成。流程为:先训练一个检测模型;再用该模型对图片推理,自动生成 labelme 格式的标注;最后在 labelme 中人工检查并微调。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import io
import base64

# add python path of PadleDetection to sys.path
from ppdet.data.source.category import get_categories
from ppdet.optimizer import ModelEMA
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.visualizer import visualize_results

parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

import warnings
warnings.filterwarnings('ignore')
import glob
import paddle
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_npu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.slim import build_slim_model
from ppdet.metrics import  get_infer_results
from PIL import Image, ImageOps

import numpy as np
from ppdet.utils.logger import setup_logger
logger = setup_logger('train')


def str2bool(value):
    """Convert a command-line string such as ``"False"`` into a real bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string -- including ``"False"`` and ``"0"`` -- so boolean
    flags need an explicit converter.

    Args:
        value: a bool (returned unchanged) or a string.

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if the string is not a recognised boolean spelling
            (argparse turns this into a normal usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError("not a boolean value: {!r}".format(value))


def parse_args():
    """Parse the command-line flags for labelme auto-annotation.

    ``ArgsParser`` comes from ppdet; ``main()`` additionally reads
    ``config`` and ``opt`` from the result, which presumably ArgsParser
    itself defines -- confirm against ppdet.utils.cli.

    Returns:
        argparse.Namespace: the parsed flags.
    """
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        # NOTE: machine-specific default kept for backward compatibility.
        default=r"C:\Users\86187\Desktop\classifyUnrecognizedESAll",
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="python_infer_output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    parser.add_argument(
        "--use_vdl",
        type=str2bool,  # type=bool would turn "--use_vdl False" into True
        default=False,
        help="Whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    parser.add_argument(
        "--save_txt",
        type=str2bool,  # type=bool would turn "--save_txt False" into True
        default=False,
        help="Whether to save inference result in txt.")
    args = parser.parse_args()
    return args


def img_arr_to_b64(img_pil):
    """Encode an image as base64 PNG bytes, as stored in labelme's imageData.

    Despite the (kept-for-compatibility) name, the argument is a PIL image,
    not a numpy array: anything with a PIL-style ``save(fp, format=...)``
    method works.

    Args:
        img_pil: PIL image to encode.

    Returns:
        bytes: base64 of the PNG encoding, as produced by
        ``base64.encodebytes`` (a newline after every 76 output characters).
    """
    # The old ``base64.encodestring`` fallback was only reachable on
    # Python 2 and that function no longer exists since Python 3.9.
    with io.BytesIO() as buf:
        img_pil.save(buf, format="PNG")
        return base64.encodebytes(buf.getvalue())


def _get_save_image_name(output_dir, image_path):
    """Return ``output_dir`` joined with the file name of ``image_path``.

    ``output_dir`` (including missing parents) is created if needed.

    Args:
        output_dir: directory the result should live in.
        image_path: source path; only its final component is used.

    Returns:
        str: ``os.path.join(output_dir, basename(image_path))``.
    """
    # exist_ok avoids the race between an os.path.exists check and makedirs.
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, os.path.basename(image_path))


def get_test_images(infer_dir, infer_img):
    """Return the image paths to run inference on.

    ``infer_img`` has priority: when given, it must be an existing file and
    is returned on its own without ``infer_dir`` being examined at all, so a
    stale default ``--infer_dir`` cannot break a single-image run.
    Otherwise every non-hidden entry directly inside ``infer_dir`` whose
    extension is jpg/jpeg/png/bmp (in any letter case) is returned.

    Args:
        infer_dir: directory to scan, or None.
        infer_img: single image path, or None.

    Returns:
        list[str]: ``[infer_img]``, or the absolute matching paths in
        ``infer_dir`` sorted so the processing order is deterministic.

    Raises:
        AssertionError: if neither argument is set, ``infer_img`` is not a
            file, ``infer_dir`` is not a directory, or no image is found.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"

    # infer_img has a higher priority: validate and return it before
    # touching infer_dir.
    if infer_img is not None:
        assert os.path.isfile(infer_img), "{} is not a file".format(infer_img)
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)

    exts = {'jpg', 'jpeg', 'png', 'bmp'}
    # glob.escape keeps directory names containing '[', ']' or '*' literal.
    candidates = glob.glob(os.path.join(glob.escape(infer_dir), '*'))
    images = sorted(
        path for path in candidates
        if os.path.splitext(path)[1][1:].lower() in exts)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images


def _coco17_category():
    """Return the class-id -> category-id and category-id -> name maps.

    The id layout is the COCO2017 one: 80 contiguous class ids map onto the
    non-contiguous category ids 1..90. However, the names of category ids
    1-10 have been replaced with what appear to be this project's own labels
    ('barCode' and digit strings). Ids 11 onwards still carry the original
    COCO names.

    NOTE(review): only nine digits appear ('7' is missing). If the trained
    model has 11 classes, class id 10 maps to category 11 and is labelled
    'fire hydrant' -- confirm against the training label list. The imported
    but unused ``get_categories`` helper could presumably derive these maps
    from the dataset annotation instead of this hand-maintained table --
    verify.

    Returns:
        tuple(dict, dict): ``(clsid2catid, catid2name)``. ``clsid2catid`` is
        keyed by 0-based class id (keys are shifted by ``k - 1`` below).
        ``catid2name`` no longer contains the 0 -> 'background' entry
        (popped below).
    """
    # 1-based class id -> category id (COCO layout with gaps in the values).
    clsid2catid = {
        1: 1,
        2: 2,
        3: 3,
        4: 4,
        5: 5,
        6: 6,
        7: 7,
        8: 8,
        9: 9,
        10: 10,
        11: 11,
        12: 13,
        13: 14,
        14: 15,
        15: 16,
        16: 17,
        17: 18,
        18: 19,
        19: 20,
        20: 21,
        21: 22,
        22: 23,
        23: 24,
        24: 25,
        25: 27,
        26: 28,
        27: 31,
        28: 32,
        29: 33,
        30: 34,
        31: 35,
        32: 36,
        33: 37,
        34: 38,
        35: 39,
        36: 40,
        37: 41,
        38: 42,
        39: 43,
        40: 44,
        41: 46,
        42: 47,
        43: 48,
        44: 49,
        45: 50,
        46: 51,
        47: 52,
        48: 53,
        49: 54,
        50: 55,
        51: 56,
        52: 57,
        53: 58,
        54: 59,
        55: 60,
        56: 61,
        57: 62,
        58: 63,
        59: 64,
        60: 65,
        61: 67,
        62: 70,
        63: 72,
        64: 73,
        65: 74,
        66: 75,
        67: 76,
        68: 77,
        69: 78,
        70: 79,
        71: 80,
        72: 81,
        73: 82,
        74: 84,
        75: 85,
        76: 86,
        77: 87,
        78: 88,
        79: 89,
        80: 90
    }

    # Category id -> label written into the labelme JSON / visualisation.
    # Ids 1-10 hold the custom labels; the remaining entries are COCO names.
    catid2name = {
        0: 'background',
        1: 'barCode',
        2: '5',
        3: '4',
        4: '8',
        5: '6',
        6: '9',
        7: '3',
        8: '0',
        9: '2',
        10: '1',
        11: 'fire hydrant',
        13: 'stop sign',
        14: 'parking meter',
        15: 'bench',
        16: 'bird',
        17: 'cat',
        18: 'dog',
        19: 'horse',
        20: 'sheep',
        21: 'cow',
        22: 'elephant',
        23: 'bear',
        24: 'zebra',
        25: 'giraffe',
        27: 'backpack',
        28: 'umbrella',
        31: 'handbag',
        32: 'tie',
        33: 'suitcase',
        34: 'frisbee',
        35: 'skis',
        36: 'snowboard',
        37: 'sports ball',
        38: 'kite',
        39: 'baseball bat',
        40: 'baseball glove',
        41: 'skateboard',
        42: 'surfboard',
        43: 'tennis racket',
        44: 'bottle',
        46: 'wine glass',
        47: 'cup',
        48: 'fork',
        49: 'knife',
        50: 'spoon',
        51: 'bowl',
        52: 'banana',
        53: 'apple',
        54: 'sandwich',
        55: 'orange',
        56: 'broccoli',
        57: 'carrot',
        58: 'hot dog',
        59: 'pizza',
        60: 'donut',
        61: 'cake',
        62: 'chair',
        63: 'couch',
        64: 'potted plant',
        65: 'bed',
        67: 'dining table',
        70: 'toilet',
        72: 'tv',
        73: 'laptop',
        74: 'mouse',
        75: 'remote',
        76: 'keyboard',
        77: 'cell phone',
        78: 'microwave',
        79: 'oven',
        80: 'toaster',
        81: 'sink',
        82: 'refrigerator',
        84: 'book',
        85: 'clock',
        86: 'vase',
        87: 'scissors',
        88: 'teddy bear',
        89: 'hair drier',
        90: 'toothbrush'
    }

    # Shift keys to 0-based class ids (presumably what the model emits and
    # get_infer_results looks up -- confirm).
    clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
    # Drop the 'background' entry so only real categories are returned.
    catid2name.pop(0)

    return clsid2catid, catid2name


def getImagesLabels(image, bboxes, im_id, save_image_name, catid2name, draw_threshold):
    """Write a labelme JSON annotation for one image from its detections.

    The JSON file is written next to ``save_image_name``, with the extension
    replaced by ``.json``. It embeds the image itself as base64 PNG in
    ``imageData``.

    Args:
        image: PIL image; ``image.size`` is read as ``(width, height)``.
        bboxes: detection dicts with keys ``image_id``, ``category_id``,
            ``score`` and ``bbox`` (``[xmin, ymin, width, height]``).
            ``None`` (no bbox output for the batch) means no detections.
        im_id: only detections whose ``image_id`` equals this are written.
        save_image_name: path whose directory and stem name the JSON file.
        catid2name: mapping category id -> label string.
        draw_threshold: detections with a score below this are skipped.
    """
    # splitext strips only the final extension, so dotted names such as
    # "img.v1.jpg" or paths like "./labelme/x.jpg" keep their full stem.
    save_json_name = os.path.splitext(save_image_name)[0] + '.json'
    img_w, img_h = image.size

    shapes = []
    for det in (bboxes if bboxes is not None else []):
        if det['image_id'] != im_id or det['score'] < draw_threshold:
            continue
        # bbox is [xmin, ymin, width, height]; labelme rectangles are given
        # by two corner points. float() keeps json.dump happy with numpy
        # scalar coordinates.
        xmin, ymin, box_w, box_h = (float(v) for v in det['bbox'])
        shapes.append({
            'label': catid2name[det['category_id']],
            'points': [[xmin, ymin], [xmin + box_w, ymin + box_h]],
            'group_id': None,
            'shape_type': "rectangle",
            'flags': {},
        })

    dst_data = {
        'version': "4.5.9",
        'flags': {},
        'shapes': shapes,
        # Bare file name, split with the platform's own path separator
        # instead of assuming Windows backslashes.
        'imagePath': os.path.basename(save_image_name),
        'imageData': img_arr_to_b64(image).decode('utf-8'),
        'imageHeight': img_h,
        'imageWidth': img_w,
    }

    with open(save_json_name, 'w', encoding='utf-8') as fp:
        json.dump(dst_data, fp, indent=4)
    logger.info("labelme annotation saved in {}".format(save_json_name))


def _save_batch_results(outs, imid2path, clsid2catid, catid2name,
                        labels_dir, output_dir, draw_threshold):
    """Save a labelme JSON and a visualised image for every image in a batch.

    Args:
        outs: one batch of model outputs, already converted with
            ``.numpy()``, including the ``im_id`` and ``bbox_num`` entries.
        imid2path: mapping image id -> source image path, from the dataset.
        clsid2catid: class id -> category id map for get_infer_results.
        catid2name: category id -> label name map.
        labels_dir: directory receiving the labelme ``.json`` files.
        output_dir: directory receiving the visualised images.
        draw_threshold: minimum score for a detection to be kept.
    """
    batch_res = get_infer_results(outs, clsid2catid)
    bbox_num = outs['bbox_num']
    start = 0
    for i, im_id in enumerate(outs['im_id']):
        image_id = int(im_id)
        image_path = imid2path[image_id]
        image = Image.open(image_path).convert('RGB')
        image = ImageOps.exif_transpose(image)
        # bbox results for the whole batch are flat; bbox_num[i] says how
        # many of them belong to image i.
        end = start + bbox_num[i]
        bbox_res = batch_res['bbox'][start:end] if 'bbox' in batch_res else None
        start = end

        # Generate the labelme annotation from the detections (done before
        # visualisation so the embedded image carries no drawn boxes).
        save_label_name = _get_save_image_name(labels_dir, image_path)
        getImagesLabels(image, bbox_res, image_id, save_label_name,
                        catid2name, draw_threshold)

        # Visualise the detections and save the annotated image.
        image = visualize_results(image, bbox_res, None, None, None,
                                  image_id, catid2name, draw_threshold)
        save_name = _get_save_image_name(output_dir, image_path)
        logger.info("Detection bbox results save in {}".format(save_name))
        image.save(save_name, quality=95)


def run(FLAGS, cfg):
    """Detect objects in the requested images and emit labelme annotations.

    For every image, a labelme ``.json`` file is written into ``labelme/``
    and a visualised copy of the image into ``FLAGS.output_dir``.

    Args:
        FLAGS: namespace from parse_args(); reads ``infer_dir``,
            ``infer_img``, ``draw_threshold`` and ``output_dir``.
        cfg: PaddleDetection config as prepared by main().
    """
    # Honour the CLI flags; both used to be hard-coded here (0.5 / 'output'),
    # silently ignoring --draw_threshold and --output_dir.
    draw_threshold = getattr(FLAGS, 'draw_threshold', 0.5)
    output_dir = getattr(FLAGS, 'output_dir', None) or 'output'
    labels_dir = 'labelme'

    dataset = cfg['TestDataset']
    images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)

    # Build the model, load normalisation params for deploy, load weights.
    model = create(cfg.architecture)
    model.load_meanstd(cfg['TestReader']['sample_transforms'])
    load_pretrain_weight(model, cfg.weights)

    dataset.set_images(images)
    loader = create('TestReader')(dataset, 0)
    imid2path = dataset.get_imid2path()
    clsid2catid, catid2name = _coco17_category()

    model.eval()
    # Pure inference: no autograd bookkeeping. Each batch is written out as
    # soon as it is predicted instead of buffering every output in memory.
    with paddle.no_grad():
        for data in loader:
            outs = model(data)
            for key in ('im_shape', 'scale_factor', 'im_id'):
                outs[key] = data[key]
            outs = {key: value.numpy() if hasattr(value, 'numpy') else value
                    for key, value in outs.items()}
            _save_batch_results(outs, imid2path, clsid2catid, catid2name,
                                labels_dir, output_dir, draw_threshold)


def main():
    """Script entry point: parse flags, prepare config and device, then run()."""
    flags = parse_args()

    cfg = load_config(flags.config)
    cfg['use_vdl'] = flags.use_vdl
    cfg['vdl_log_dir'] = flags.vdl_log_dir
    merge_config(flags.opt)

    # NPU is opt-in: a config that does not mention it runs without NPU.
    if 'use_npu' not in cfg:
        cfg.use_npu = False

    # Device priority: GPU, then NPU, otherwise CPU.
    if cfg.use_gpu:
        device = 'gpu'
    elif cfg.use_npu:
        device = 'npu'
    else:
        device = 'cpu'
    paddle.set_device(device)

    # sync_bn is only kept when running on GPU; fall back to plain 'bn'.
    uses_sync_bn = 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn'
    if uses_sync_bn and not cfg.use_gpu:
        cfg['norm_type'] = 'bn'

    if flags.slim_config:
        cfg = build_slim_model(cfg, flags.slim_config, mode='test')

    # Validate the final config and the runtime environment before running.
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_npu(cfg.use_npu)
    check_version()

    run(flags, cfg)


if __name__ == '__main__':
    main()

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值