Lagent调用mmdetection api

1. 安装mmdetection

1.1 使用 MIM 安装 MMEngine 和 MMCV

pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"

1.2 安装mmdet

mim install mmdet

2. 调用mmdet api

2.1 导入DetInferencer,初始化模型

使用模型为RTMDet,实时目标检测模型(Real-Time Models for object Detection)。

该模型支持识别COCO数据集的80个类别,如下。

#mmdetection
from mmdet.apis import DetInferencer

# Create the detection inferencer from the model name
# 'rtmdet_tiny_8xb32-300e_coco'; it is reused for every inference call below.
inferencer_mmdet = DetInferencer(model='rtmdet_tiny_8xb32-300e_coco')

# COCO dataset class names. The position of a name in this tuple is the label
# id it stands for, so the order must not change. Five names per row; the
# trailing comment gives the index range of that row.
classes_cocodataset = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',  # 0-4
    'bus', 'train', 'truck', 'boat', 'traffic light',  # 5-9
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',  # 10-14
    'cat', 'dog', 'horse', 'sheep', 'cow',  # 15-19
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',  # 20-24
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',  # 25-29
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',  # 30-34
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle',  # 35-39
    'wine glass', 'cup', 'fork', 'knife', 'spoon',  # 40-44
    'bowl', 'banana', 'apple', 'sandwich', 'orange',  # 45-49
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',  # 50-54
    'cake', 'chair', 'couch', 'potted plant', 'bed',  # 55-59
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',  # 60-64
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',  # 65-69
    'toaster', 'sink', 'refrigerator', 'book', 'clock',  # 70-74
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',  # 75-79
)

2.2 推理和结果

# Run detection on one image (`image_path` must already be defined).
# `out_dir` together with `no_save_pred=False` asks the inferencer to write
# its outputs under ./outputs/.
result = inferencer_mmdet(image_path, out_dir='./outputs/', no_save_pred=False)
# 'predictions' holds one entry per input image; a single image was passed.
result_prediction = result.get('predictions')[0]
result_labels = result_prediction.get('labels')
result_scores = result_prediction.get('scores')
result_bboxes = result_prediction.get('bboxes')

# Only the first detection is used. Indexing an empty label list would raise
# IndexError when nothing is detected, so fall back to "unknown" in that case.
if result_labels is not None and len(result_labels) > 0:
    image_class = classes_cocodataset[result_labels[0]]
else:
    image_class = "unknown"

推理得到的结果有三个:labels,scores,bboxes。

目前只取score最高的那一个class作为image_class。

完整代码:

import os
from typing import List, Optional, Tuple, Union

import requests
import json

from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction

#mmdetection
from mmdet.apis import DetInferencer

# Build the detector once, when this module is imported, from the model name
# 'rtmdet_tiny_8xb32-300e_coco'. Every ImageRecognition call reuses this
# single module-level instance.
inferencer_mmdet = DetInferencer(model='rtmdet_tiny_8xb32-300e_coco')

# COCO dataset class names. A predicted label id is used directly as an index
# into this tuple, so the order must not change. Five names per row; the
# trailing comment gives the index range of that row.
classes_cocodataset = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',  # 0-4
    'bus', 'train', 'truck', 'boat', 'traffic light',  # 5-9
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',  # 10-14
    'cat', 'dog', 'horse', 'sheep', 'cow',  # 15-19
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',  # 20-24
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',  # 25-29
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',  # 30-34
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle',  # 35-39
    'wine glass', 'cup', 'fork', 'knife', 'spoon',  # 40-44
    'bowl', 'banana', 'apple', 'sandwich', 'orange',  # 45-49
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',  # 50-54
    'cake', 'chair', 'couch', 'potted plant', 'bed',  # 55-59
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',  # 60-64
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',  # 65-69
    'toaster', 'sink', 'refrigerator', 'book', 'clock',  # 70-74
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',  # 75-79
)


# Default value of ImageRecognition's `description` parameter. The text is
# runtime data and is intentionally kept in Chinese; in English it reads:
# "An API for image recognition. Use it when you need to recognise an image.
# Prefer ImageRecognition for image recognition. The input should be the path
# of an image file, or the URL of an image."
DEFAULT_DESCRIPTION = """一个进行图片识别的API。
当你需要对于一个图片进行识别时,可以使用这个API。
优先使用ImageRecognition来进行图片识别。
输入应该是一张图片文件的路径,或者是图片的URL。
"""


class ImageRecognition(BaseAction):
    """Action that names the first object MMDetection finds in an image.

    The query is parsed for an image path, the module-level
    ``inferencer_mmdet`` is run on that path, and the first predicted label
    is looked up in ``classes_cocodataset``.
    """

    def __init__(self,
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None) -> None:
        """Initialise the action.

        Args:
            description (str): Text describing the tool.
            name (Optional[str]): Name of the action.
            enable (bool): Whether the action is enabled.
            disable_description (Optional[str]): Description used when the
                action is disabled.
        """
        # Forward every argument: passing only `description` silently
        # discarded the caller's name / enable / disable_description.
        super().__init__(description, name, enable, disable_description)

    def __call__(self, query: str) -> ActionReturn:
        """Return the image recognition response.

        Args:
            query (str): The query that carries the image path.

        Returns:
            ActionReturn: ``result`` is ``dict(text=...)``. ``state`` is
            ``SUCCESS`` when recognition ran, or ``API_ERROR`` (with the
            exception message as text) when anything raised.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        try:
            response = self._image_recognition(query)
            tool_return.result = dict(text=str(response))
            tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of
            # letting it propagate.
            tool_return.result = dict(text=str(e))
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    @staticmethod
    def _extract_image_path(query: str) -> Optional[str]:
        """Pull the image path out of ``query``.

        Accepts a JSON object with an ``"image_path"`` key, a JSON string,
        or, because the tool description advertises a plain path or URL,
        raw text that is not JSON at all.

        Returns:
            Optional[str]: The image path, or ``None`` when none is present.
        """
        try:
            data = json.loads(query)
        except json.JSONDecodeError:
            # Not JSON: treat the text itself as the path or URL.
            return query.strip() or None
        if isinstance(data, dict):
            return data.get("image_path", None)
        if isinstance(data, str):
            return data
        return None

    def _image_recognition(self,
                query: str) -> str:
        """Run detection on the image named by ``query``.

        Args:
            query (str): JSON such as ``{"image_path": "..."}`` or a bare
                image path / URL.

        Returns:
            str: ``'image recognition response here is a '`` followed by the
            class name, or by ``"unknown"`` when no image path was given or
            nothing was detected.
        """
        print("Enter Image Recognition entry")
        image_path = self._extract_image_path(query)
        if image_path is None:
            print("image_path不存在")
            image_class = "unknown"
        else:
            result = inferencer_mmdet(
                image_path,
                out_dir='./outputs/',
                no_save_pred=False,
                print_result=False)
            # 'predictions' holds one entry per input image; one was passed.
            prediction = result.get('predictions')[0]
            labels = prediction.get('labels')
            # Only the first detection is used. An empty result would make
            # labels[0] raise IndexError, so fall back to "unknown".
            if labels is not None and len(labels) > 0:
                image_class = classes_cocodataset[labels[0]]
            else:
                image_class = "unknown"
        return 'image recognition response here is a ' + image_class

  • 10
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值