mask R-cnn

import cv2  
import numpy as np  
  
  
def random_colors(N):  # 定义随机颜色函数  
    np.random.seed(1)  
    colors=[tuple(255*np.random.rand(3)) for _ in range(N)]  
    return colors  
  
def apply_mask(image, mask, color, alpha=0.5):  
    """在图片中增加掩码
    """  
    for n, c in enumerate(color):  
        image[:, :, n] = np.where(  
            mask == 1,  
            image[:, :, n] *(1 - alpha) + alpha * c,  
            image[:, :, n]  
        )  #若掩码为1,则更新颜色
    return image


#可视化每个实例
def display_instances(image,boxes,masks,ids,names,scores):  
    n_instances=boxes.shape[0]  
    if not n_instances:  
        print('No instances to display')  
    else:  
        assert boxes.shape[0] == masks.shape[-1] == ids.shape[0]  
      
    colors=random_colors(n_instances)  
    height, width = image.shape[:2]  #图像大小
      
    for i,color in enumerate(colors):  
        if not np.any(boxes[i]):  
            continue  


        label=names[ids[i]]
       # if label=='person':
       #     image[:, :, i] = np.where(  
       #     mask == 1,  
       #     image[:, :, i] *(1 - alpha) + alpha * c,  
       #     image[:, :, i]
       #     )
        y1,x1,y2,x2=boxes[i]  #box坐标
        mask=masks[:,:,i]   
        image=apply_mask(image,mask,color)  
        image=cv2.rectangle(image,(x1,y1),(x2,y2),color,2)  #画出检测框
          


        score=scores[i] if scores is not None else None  
          
        caption='{}{:.2f}'.format(label,score) if score else label  
        image=cv2.putText(  
            image,caption,(x1,y1),cv2.FONT_HERSHEY_COMPLEX,0.7,color,2
        )  #增加类别及得分
          
    return image  
  
if __name__=='__main__':  
    import os  
    import sys  
    import random  
    import math  
    import skimage.io  
    import time  
    import utils  
    #import model as modellib  
      
      
    ROOT_DIR = os.path.abspath("../")    #根目录地址
    sys.path.append(ROOT_DIR)  
    from mrcnn import utils  
    import mrcnn.model as modellib  
  
   
    sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))   
    import coco  
      
  
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")  
    COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")  #预训练参数地址
    if not os.path.exists(COCO_MODEL_PATH):  
        print('cannot find coco_model')
        
   #本机使用GTX-960M显卡4GB显存,只能处理一张图         
    class InferenceConfig(coco.CocoConfig):  
        GPU_COUNT = 1  
        IMAGES_PER_GPU = 1  
  
    config = InferenceConfig()  
    config.display()  
      
    model = modellib.MaskRCNN(  
        mode="inference", model_dir=MODEL_DIR, config=config  
    )  
  
    #导入权重 
    model.load_weights(COCO_MODEL_PATH, by_name=True)  
    class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',  
               'bus', 'train', 'truck', 'boat', 'traffic light',  
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',  
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',  
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',  
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',  
               'kite', 'baseball bat', 'baseball glove', 'skateboard',  
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',  
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',  
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',  
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',  
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',  
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',  
               'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',  
               'teddy bear', 'hair drier', 'toothbrush']


    #打开摄像头
    capture=cv2.VideoCapture(0)  
    capture.set(cv2.CAP_PROP_FRAME_WIDTH,1920)  
    capture.set(cv2.CAP_PROP_FRAME_HEIGHT,1080)  #设置输出分辨率
      
    while True:  
        ret,frame=capture.read()  #获得帧
        results=model.detect([frame],verbose=0) #将每帧实时图像导入模型
        r=results[0]  
          
          
        frame=display_instances(  
              frame,r['rois'], r['masks'], r['class_ids'],   
                            class_names, r['scores']  
        )  
          
        cv2.imshow('realtime instance',frame)  
        if cv2.waitKey(1)&0xFF==ord('q'):  
            break  
         
    capture.release()  
    cv2.destroyAllWindows() 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值