#采用自己的模型进行预测,或者直接使用COCO权重进行全部类别的预测看这里:
目录
#采用自己的模型进行预测,或者直接使用COCO权重进行全部类别的预测看这里:
1.问题提出
大家可能都知道,这个版本的COCO一共有如下80个类别:
但是针对某些特定的场景,你可能只需要检测出其中的一种或者多种,例如只检测'traffic light'这一类,而不检测其他的类该怎么做呢?类似下图:
a.直接使用COCO权重预测:
b.经过筛选后,只检测'traffic light'这一类:
2.解决思路
其实解决这个问题的办法很简单,每次调用detect对图像处理后,都可以得到一系列boungding box,而每一个boungding box也都对应这一个个class_id,因此只需要对这些class_id进行筛选就可以得到需要的这些类了。
3.得到class_id
因此,问题就变成了如何拿到这些类对应的class_id了,你可以通过以下语句得到class_id:
class_id = class_names.index('traffic light')
4.筛选出需要的boungding box
拿到这些class_id后就可以按他们筛选出需要的boungding box:
if class_names.index('traffic light') in r['class_ids']:
k = list(np.where(r['class_ids'] == class_names.index('traffic light'))[0])
r['scores'] = np.array([r['scores'][i] for i in k])
r['rois'] = np.array([r['rois'][i] for i in k])
r['masks'] = np.array([r['masks'][i] for i in k])
r['class_ids'] = np.array([r['class_ids'][i] for i in k])
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], class_names, r['scores'], figsize=(8, 8))
其他步骤和上面贴出来的链接一致,对应修改即可!