import cv2
import numpy as np
def random_colors(N):
    """Return ``N`` deterministic pseudo-random colours.

    Each colour is a 3-tuple of floats in ``[0, 255)``. The colours come from a
    private ``np.random.RandomState`` seeded with 1, so the same ``N`` always
    yields the same colours (exactly the values the former
    ``np.random.seed(1)`` approach produced). Unlike that approach, this does
    not reset NumPy's global random state for the rest of the program.

    Args:
        N: number of colours to generate.

    Returns:
        list of ``N`` colour tuples.
    """
    rng = np.random.RandomState(1)  # local generator: leaves global RNG alone
    return [tuple(255 * rng.rand(3)) for _ in range(N)]
def apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask`` equals 1.

    For every channel ``k`` listed in ``color``, a masked pixel becomes
    ``image[:, :, k] * (1 - alpha) + alpha * color[k]``; unmasked pixels keep
    their value. Results are written back into ``image`` (so they are cast to
    its dtype on assignment), and the same ``image`` object is returned.

    Args:
        image: array indexed as ``image[:, :, channel]``.
        mask: array compared against 1 to select the pixels to tint.
        color: per-channel colour values.
        alpha: weight of ``color`` in the blend.
    """
    inside = mask == 1
    keep = 1 - alpha
    for channel, value in enumerate(color):
        plane = image[:, :, channel]
        blended = plane * keep + alpha * value
        image[:, :, channel] = np.where(inside, blended, plane)
    return image
def _format_caption(label, score):
    """Return the text drawn at a box corner: ``"<label> <score:.2f>"`` or ``label``.

    ``score is None`` means no score is available. A score of ``0.0`` is still
    printed; a plain truthiness test would have silently dropped it.
    """
    if score is None:
        return label
    return '{} {:.2f}'.format(label, score)


# Visualise every detected instance.
def display_instances(image, boxes, masks, ids, names, scores):
    """Draw each instance's mask, bounding box and caption onto ``image``.

    Args:
        image: image array; drawn on in place (``apply_mask`` writes into it)
            and also returned.
        boxes: array whose rows are ``(y1, x1, y2, x2)`` box corners. Rows that
            are entirely zero are skipped.
        masks: array whose last axis indexes instances; ``masks[:, :, i]`` is
            the mask of instance ``i`` (presumably H x W like ``image`` —
            confirm against the model output).
        ids: class index per instance, used to look up ``names``.
        names: sequence of class names indexed by ``ids``.
        scores: per-instance scores, or ``None`` to draw labels only.

    Returns:
        The annotated image (returned unchanged when there are no instances).

    Raises:
        AssertionError: if ``boxes``, ``masks`` and ``ids`` disagree on the
            number of instances.
    """
    n_instances = boxes.shape[0]
    if not n_instances:
        print('No instances to display')
        return image
    assert boxes.shape[0] == masks.shape[-1] == ids.shape[0]

    colors = random_colors(n_instances)
    for i, color in enumerate(colors):
        if not np.any(boxes[i]):
            continue  # all-zero box row: nothing to draw
        label = names[ids[i]]
        # Convert to plain ints: some OpenCV versions reject NumPy integer
        # scalars inside point tuples.
        y1, x1, y2, x2 = (int(v) for v in boxes[i])
        image = apply_mask(image, masks[:, :, i], color)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)  # box
        score = scores[i] if scores is not None else None
        image = cv2.putText(
            image, _format_caption(label, score), (x1, y1),
            cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 2,
        )  # class name and score
    return image
def main():
    """Run live Mask R-CNN instance segmentation on camera 0.

    Loads COCO-pretrained weights. It then repeatedly grabs a frame, runs
    detection, draws the results with ``display_instances`` and shows them in
    an OpenCV window. Press ``q`` in that window to quit. The camera and window
    are released even if an exception (or Ctrl-C) interrupts the loop.
    """
    import os
    import sys

    # NOTE(review): resolved relative to the current working directory, not to
    # this file; run the script from one level below the Mask R-CNN repo root.
    root_dir = os.path.abspath("../")  # repository root
    sys.path.append(root_dir)
    import mrcnn.model as modellib
    sys.path.append(os.path.join(root_dir, "samples/coco/"))
    import coco

    model_dir = os.path.join(root_dir, "logs")
    coco_model_path = os.path.join(root_dir, "mask_rcnn_coco.h5")  # pretrained weights
    if not os.path.exists(coco_model_path):
        # Stop here instead of continuing into a crash inside load_weights.
        print('cannot find coco_model')
        sys.exit(1)

    # One image per batch: the original author's GTX 960M (4 GB) could only
    # process a single image at a time.
    class InferenceConfig(coco.CocoConfig):
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1

    config = InferenceConfig()
    config.display()
    model = modellib.MaskRCNN(
        mode="inference", model_dir=model_dir, config=config
    )
    model.load_weights(coco_model_path, by_name=True)  # load the weights

    class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
                   'bus', 'train', 'truck', 'boat', 'traffic light',
                   'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
                   'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
                   'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                   'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                   'kite', 'baseball bat', 'baseball glove', 'skateboard',
                   'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                   'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                   'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                   'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                   'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                   'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
                   'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                   'teddy bear', 'hair drier', 'toothbrush']

    capture = cv2.VideoCapture(0)  # open the default camera
    if not capture.isOpened():
        print('cannot open camera')
        sys.exit(1)
    capture.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)  # requested output resolution
    try:
        while True:
            ret, frame = capture.read()  # grab a frame
            if not ret:
                # No frame (camera unplugged / stream ended): frame would be None.
                break
            # OpenCV delivers BGR. NOTE(review): the model is assumed to expect
            # RGB input (Matterport's config documents MEAN_PIXEL as RGB) —
            # confirm. Detect on an RGB copy and keep drawing on the BGR frame
            # that imshow displays.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            r = model.detect([rgb], verbose=0)[0]
            frame = display_instances(
                frame, r['rois'], r['masks'], r['class_ids'],
                class_names, r['scores']
            )
            cv2.imshow('realtime instance', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        capture.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
# Mask R-CNN
# (Footer of the blog page this script was copied from: "latest recommended
# article published 2023-01-18 10:27:39".)