1、模型选择,以及分类类型:
CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
NETS = {
'vgg16': ('VGG16',
'VGG16_faster_rcnn_final.caffemodel'),
'zf': ('ZF',
'ZF_faster_rcnn_final.caffemodel')}
CLASSES后面的是你需要分类目标的名称,NETS后面的是你训练好的模型的名称。
2、vis_detections函数,用来使得检测结果可视化,即在图片中展示出检测结果,包括物体框和类别以及得分。
def vis_detections(im, class_name, dets, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.whe