Label Studio 半自动化标注：在 ML 后端部署自定义模型

后端主程序代码（启动 ML 后端服务的入口脚本）：

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import logging
import logging.config
import os

# Route all log records to stdout through a single console handler. The
# handler itself lets DEBUG and above through, but the root logger starts at
# ERROR; the script's --log-level option can lower the root level later.
logging.config.dictConfig(dict(
    version=1,
    formatters=dict(
        standard=dict(
            format=('[%(asctime)s] [%(levelname)s] '
                    '[%(name)s::%(funcName)s::%(lineno)d] %(message)s'),
        ),
    ),
    handlers=dict(
        console={
            # 'class' is a Python keyword, so this entry stays a literal.
            'class': 'logging.StreamHandler',
            'level': 'DEBUG',
            'stream': 'ext://sys.stdout',
            'formatter': 'standard',
        },
    ),
    root=dict(
        level='ERROR',
        handlers=['console'],
        propagate=True,
    ),
))

# Optional model-kwargs file, looked up next to this script.
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
    """Load model initialization kwargs from a JSON config file.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.
            Defaults to ``config.json`` in this script's directory.

    Returns:
        The decoded dict, or an empty dict when ``config_path`` does not exist.

    Raises:
        TypeError: if the top-level JSON value is not an object.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if not os.path.exists(config_path):
        return {}
    # Decode explicitly: the platform default (e.g. GBK on Chinese Windows)
    # would mis-read UTF-8 files, and 'utf-8-sig' also accepts the BOM that
    # Windows editors often write.
    with open(config_path, encoding='utf-8-sig') as f:
        config = json.load(f)
    # A real check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(config, dict):
        raise TypeError(
            f'{config_path} must contain a JSON object, '
            f'got {type(config).__name__}')
    return config


def split_kwarg(item):
    """Split a ``KEY=VAL`` command-line item into a ``(key, value)`` pair.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=`` (for example URLs with query strings).

    Raises:
        argparse.ArgumentTypeError: if ``item`` has no ``=`` or an empty key;
            argparse reports this as a normal usage error.
    """
    key, sep, value = item.partition('=')
    if not sep or not key:
        raise argparse.ArgumentTypeError(f'expected KEY=VAL, got {item!r}')
    return key, value


def coerce_kwarg_value(value):
    """Convert a command-line string to bool, int or float, else keep the str.

    'True'/'true' -> True and 'False'/'false' -> False. Otherwise an int is
    tried first (so negative numbers stay ints), then a float. Anything else
    is returned unchanged.
    """
    if value in ('True', 'true'):
        return True
    if value in ('False', 'false'):
        return False
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            pass
    return value


if __name__ == '__main__':

    from label_studio_ml.api import init_app
    from label_studio_ml_model import MMDetection

    parser = argparse.ArgumentParser(description='Label studio')
    parser.add_argument(
        '-p',
        '--port',
        dest='port',
        type=int,
        default=9090,
        help='Server port')
    parser.add_argument(
        '--host', dest='host', type=str, default='0.0.0.0', help='Server host')
    parser.add_argument(
        '--kwargs',
        '--with',
        dest='kwargs',
        metavar='KEY=VAL',
        nargs='+',
        type=split_kwarg,
        help='Additional LabelStudioMLBase model initialization kwargs')
    parser.add_argument(
        '-d',
        '--debug',
        dest='debug',
        action='store_true',
        help='Switch debug mode')
    parser.add_argument(
        '--log-level',
        dest='log_level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        default=None,
        help='Logging level')
    parser.add_argument(
        '--model-dir',
        dest='model_dir',
        default=None,
        help='Directory models are store',
    )
    parser.add_argument(
        '--check',
        dest='check',
        action='store_true',
        help='Validate model instance before launching server')

    args = parser.parse_args()

    # Lower (or raise) the root level configured at import time.
    if args.log_level:
        logging.root.setLevel(args.log_level)

    # config.json supplies defaults; --with KEY=VAL items override them.
    kwargs = get_kwargs_from_config()
    if args.kwargs:
        kwargs.update(
            {key: coerce_kwarg_value(value) for key, value in args.kwargs})

    if args.check:
        print('Check "' + MMDetection.__name__ + '" instance creation..')
        # Instantiating once surfaces constructor errors before serving.
        MMDetection(**kwargs)

    app = init_app(
        model_class=MMDetection,
        model_dir=os.environ.get('MODEL_DIR', args.model_dir),
        redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
        redis_host=os.environ.get('REDIS_HOST', 'localhost'),
        # Environment values are strings; normalize so the port is always an
        # int whether or not REDIS_PORT is set.
        redis_port=int(os.environ.get('REDIS_PORT', 6379)),
        **kwargs)

    app.run(host=args.host, port=args.port, debug=args.debug)

编写后端模型代码时，切记要在模块级别（即 class 定义之外）提前加载模型：这样模型只在模块导入时加载一次，之后每次预测都直接复用，无需重复加载。

# Copyright (c) OpenMMLab. All rights reserved.
import io
import json
import logging
import os
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size,
                                   get_single_tag_keys)
from label_studio_tools.core.utils.io import get_data_dir

from mmdet.apis import inference_detector, init_detector
import onnxruntime
from onnx_PredictConfig import PredictConfig
from preprocess_img import Compose
import cv2
import numpy as np
logger = logging.getLogger(__name__)

# Load the ONNX model once, when this module is imported, and share the
# session with every MMDetection instance so predictions never reload it.
# The defaults are the original hard-coded values; set the ONNX_DEVICE /
# ONNX_CHECKPOINT_FILE environment variables to override them without
# editing this file.
device = os.environ.get('ONNX_DEVICE', 'cuda:0')
checkpoint_file = os.environ.get(
    'ONNX_CHECKPOINT_FILE',
    'C:\\Users\\Administrator\\PycharmProjects\\onnx_eval\\rtdetr.onnx')
# Execution providers in order of preference: any device other than 'cpu'
# requests CUDA first, with the CPU provider listed as the next choice.
providers = (['CUDAExecutionProvider', 'CPUExecutionProvider']
             if device != 'cpu' else ['CPUExecutionProvider'])
model_pp = onnxruntime.InferenceSession(checkpoint_file, providers=providers)
print('提前加载模型,以后预测不需要每次加载')


class MMDetection(LabelStudioMLBase):
    """Label Studio ML backend that pre-annotates images with an ONNX detector.

    The class name is kept from the MMDetection example this backend was
    adapted from. Inference actually runs through the module-level
    onnxruntime session ``model_pp`` (created once at import time), with
    preprocessing built from ``PredictConfig`` / ``Compose``.
    """

    # Every kept detection is reported under this single Label Studio label;
    # the model's class id is only used to drop empty rows. Override in a
    # subclass to target a different label.
    OUTPUT_LABEL = 'tower'

    def __init__(self,
                 config_file=None,
                 checkpoint_file=None,
                 image_dir=None,
                 labels_file=None,
                 score_threshold=0.5,
                 device='cpu',
                 **kwargs):
        """Bind to the project's label config and build the preprocessing.

        Args:
            config_file: Not used; the inference config path is hard-coded.
            checkpoint_file: Not used; the model is the module-level
                ``model_pp`` session.
            image_dir: Local directory of uploaded images. Defaults to Label
                Studio's ``<data dir>/media/upload``.
            labels_file: Optional JSON file whose mapping seeds
                ``self.label_map``.
            score_threshold: Stored as ``self.score_thresh`` but not used by
                ``predict``, which filters with ``infer_config.draw_threshold``.
            device: Not used; the device is fixed when ``model_pp`` is created.
            **kwargs: Forwarded to ``LabelStudioMLBase``.
        """
        super().__init__(**kwargs)
        # NOTE(review): hard-coded paths; config_file/checkpoint_file
        # arguments are ignored.
        self.config_file = 'C:\\Users\\Administrator\\PycharmProjects\\onnx_eval\\infer_cfg.yml'
        self.checkpoint_file = 'C:\\Users\\Administrator\\PycharmProjects\\onnx_eval\\rtdetr.onnx'
        self.labels_file = labels_file
        # Default to Label Studio's image upload folder.
        upload_dir = os.path.join(get_data_dir(), 'media', 'upload')
        self.image_dir = image_dir or upload_dir
        logger.debug('%s reads images from %s',
                     self.__class__.__name__, self.image_dir)
        if self.labels_file and os.path.exists(self.labels_file):
            self.label_map = json_load(self.labels_file)
        else:
            self.label_map = {}

        self.from_name, self.to_name, self.value, labels_in_config = get_single_tag_keys(  # noqa E501
            self.parsed_label_config, 'RectangleLabels', 'Image')
        self.labels_in_config = set(labels_in_config)
        schema = list(self.parsed_label_config.values())[0]

        # Extend label_map from `predicted_values="airplane,car"` attributes on
        # <Label> tags: each listed value maps to that tag's label name.
        self.labels_attrs = schema.get('labels_attrs')
        if self.labels_attrs:
            for label_name, label_attrs in self.labels_attrs.items():
                for predicted_value in label_attrs.get('predicted_values',
                                                       '').split(','):
                    predicted_value = predicted_value.strip()
                    # Skip blanks so a tag without the attribute does not
                    # map the empty string.
                    if predicted_value:
                        self.label_map[predicted_value] = label_name

        logger.info('Load new model from: %s %s',
                    self.config_file, self.checkpoint_file)
        # Reuse the session created once at import time instead of loading
        # the ONNX file for every instance.
        self.model = model_pp
        # NOTE(review): stored but unused; predict() filters with
        # infer_config.draw_threshold instead.
        self.score_thresh = score_threshold

        self.infer_config = PredictConfig(self.config_file)
        self.transforms = Compose(self.infer_config.preprocess_infos)

    def _get_image_url(self, task):
        """Return the task's image URL, presigning ``s3://`` URLs via boto3.

        If presigning fails, a warning is logged and the original URL is
        returned. A falsy value is returned when the task has no image under
        ``self.value`` or ``DATA_UNDEFINED_NAME``.
        """
        image_url = task['data'].get(
            self.value) or task['data'].get(DATA_UNDEFINED_NAME)
        if image_url and image_url.startswith('s3://'):
            # presign s3 url
            r = urlparse(image_url, allow_fragments=False)
            bucket_name = r.netloc
            key = r.path.lstrip('/')
            client = boto3.client('s3')
            try:
                image_url = client.generate_presigned_url(
                    ClientMethod='get_object',
                    Params={
                        'Bucket': bucket_name,
                        'Key': key
                    })
            except ClientError as exc:
                logger.warning(
                    f'Can\'t generate presigned URL for {image_url}. Reason: {exc}'  # noqa E501
                )
        return image_url

    def predict(self, tasks, **kwargs):
        """Predict bounding boxes for each task.

        Args:
            tasks: Label Studio task dicts, each referencing one image.
            **kwargs: Unused.

        Returns:
            One ``{'result': [...], 'score': float}`` dict per task, in order.
            ``score`` is the mean confidence of the returned boxes, or 0.0 if
            there are none.
        """
        return [self._predict_task(task) for task in tasks]

    def _predict_task(self, task):
        """Run the ONNX model on one task's image and build its prediction."""
        output_label = self.OUTPUT_LABEL
        # The label is fixed, so check it once up front and skip inference
        # entirely when the project cannot accept it.
        if output_label not in self.labels_in_config:
            logger.warning('%s label not found in project config.',
                           output_label)
            return {'result': [], 'score': 0.0}

        image_url = self._get_image_url(task)
        if not image_url:
            raise ValueError(
                f'Task has no image under data key {self.value!r}')
        image_path = self.get_local_path(image_url)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None.
            raise ValueError(f'Could not read image {image_path!r}')

        inputs = self.transforms(image)
        input_names = [var.name for var in self.model.get_inputs()]
        # Prepend a batch axis of length 1 to each input the model declares.
        feed = {name: inputs[name][None, ...] for name in input_names}
        model_results = self.model.run(output_names=None, input_feed=feed)

        img_width, img_height = get_image_size(image_path)
        threshold = self.infer_config.draw_threshold
        results = []
        scores = []
        # Rows of the first output are read as
        # [class_id, score, x_min, y_min, x_max, y_max]. The coordinates are
        # assumed to be pixels of the original image (TODO confirm against the
        # export); they are converted to the percentages Label Studio uses.
        for bbox in np.asarray(model_results[0]):
            if not (bbox[0] > -1 and bbox[1] > threshold):
                continue
            score = float(bbox[1])
            logger.debug('%d %s %s %s %s %s', int(bbox[0]), bbox[1],
                         bbox[2], bbox[3], bbox[4], bbox[5])
            x_min, y_min, x_max, y_max = (float(v) for v in bbox[2:6])
            results.append({
                'from_name': self.from_name,
                'to_name': self.to_name,
                'type': 'rectanglelabels',
                'value': {
                    'rectanglelabels': [output_label],
                    'x': x_min / img_width * 100,
                    'y': y_min / img_height * 100,
                    'width': (x_max - x_min) / img_width * 100,
                    'height': (y_max - y_min) / img_height * 100
                },
                'score': score
            })
            scores.append(score)

        avg_score = sum(scores) / max(len(scores), 1)
        logger.debug('>>> RESULTS: %s', results)
        return {'result': results, 'score': avg_score}


def json_load(file, int_keys=False):
    """Read a UTF-8 JSON file and return the decoded value.

    When ``int_keys`` is true, the top-level value must be a mapping; it is
    returned as a new dict whose keys have been passed through ``int()``.
    """
    with open(file, encoding='utf8') as handle:
        payload = json.load(handle)
    if not int_keys:
        return payload
    return {int(key): value for key, value in payload.items()}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值