商汤旋转框检测,预测代码及自定义可视化

# 以下两行别删,模型构建会用
from argparse import ArgumentParser
import mmrotate  

from mmdet.apis import inference_detector, init_detector
import cv2
import numpy as np

def cv_show(name:str,img:np.array):
    # 窗口大小可调整
    cv2.namedWindow(name,0)
    cv2.imshow(name, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return

class MMRotate_Test():
    def __init__(self) -> None:
        # 配置文件地址
        self.config_path = "configs/config.py"
        # 权重文件地址
        self.weight_path = "latest.pth"
        # 类名
        self.classes = ('50BT-B-55', 'ZG', 'TCC90', '100HT-B-55', 'TCC30R', 'TCC45L', 'TCC30L', 'TCC45R', '80HT-B-55', 'left', 'right', 'HSB-100-NBR60-IG3-8')
        # 类对应的旋转框颜色
        self.colors = [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
                (138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255),
                (255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139)]
        # 得分阈值显示
        self.score_thread = 0.1
        # device
        self.device = "cuda:0"

        # 模型构建
        self.model = init_detector(self.config_path, self.weight_path, self.device)
    
    # 预测
    def main(self,img:np.array):
        # 预测
        # ! result是一个list,list的长度为class的数量。list的每一项对应每一个class预测到的内容。
        result = inference_detector(self.model, img)

        # 可视化展示
        self.img_visual(result,img)

        return result
    
    # 数据可视化
    def img_visual(self,result,img):
        # 遍历result的每一项,读取每个类预测的内容
        for i in range(len(self.classes)):
            cla = result[i]
            cla_name = self.classes[i]
            cla_color = self.colors[i]
            if len(cla)<1:
                continue
            else:
                # 遍历每个类预测到的目标的信息
                for r in cla:
                    # 每个预测描述了目标的旋转框的中心点坐标,旋转框宽高,以及旋转角(弧度表示)
                    center_x, center_y, width, height,angle,score = r
                    if score<self.score_thread:
                        continue
                    else:
                        angle = (angle*180)/3.14
                        self.draw_rotated_box(img,center_x, center_y, width, height,angle,cla_name,cla_color)
        # 
        cv_show("img",img)
    
    def draw_rotated_box(self,image, center_x, center_y, width, height, angle,cla_name,cla_color):
        # 创建旋转框的四个角点坐标
        rect = ((center_x, center_y), (width, height), angle)
        box = cv2.boxPoints(rect)
        box = np.int0(box)

        # 在图像上绘制文字与旋转框
        cv2.putText(image,cla_name,(int(center_x),int(center_y)),cv2.FONT_HERSHEY_COMPLEX,1,cla_color,1)
        cv2.drawContours(image, [box], 0,cla_color, 2)


if __name__ == "__main__":
    import time
    mm_test = MMRotate_Test()

    img = cv2.imread("data/575.png")

    for i in range(10):
        start = time.time()
        mm_test.main(img)
        end = time.time()
        print(end-start)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值