py-faster-rcnn_caffemodel对人脸进行标注

本程序在py-faster-rcnn/tools/demo.py的基础上进行修改

程序功能:利用训练好的caffemodel,对人脸进行标注

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2.   
  3. # --------------------------------------------------------  
  4. # Faster R-CNN  
  5. # Copyright (c) 2015 Microsoft  
  6. # Licensed under The MIT License [see LICENSE for details]  
  7. # Written by Ross Girshick  
  8. # --------------------------------------------------------  
  9.   
  10. """ 
  11. Demo script showing detections in sample images. 
  12.  
  13. See README.md for installation instructions before running. 
  14. """  
  15.   
  16. import _init_paths  
  17. from fast_rcnn.config import cfg  
  18. from fast_rcnn.test import im_detect  
  19. from fast_rcnn.nms_wrapper import nms  
  20. from utils.timer import Timer  
  21. import matplotlib.pyplot as plt  
  22. import numpy as np  
  23. import scipy.io as sio  
  24. import caffe, os, sys, cv2  
  25. import argparse  
  26.   
  27. #CLASSES = ('__background__',  
  28. #           'aeroplane', 'bicycle', 'bird', 'boat',  
  29. #           'bottle', 'bus', 'car', 'cat', 'chair',  
  30. #           'cow', 'diningtable', 'dog', 'horse',  
  31. #           'motorbike', 'person', 'pottedplant',  
  32. #           'sheep', 'sofa', 'train', 'tvmonitor')  
  33.   
  34. CLASSES = ('__background__','face')  
  35.   
  36. NETS = {'vgg16': ('VGG16',  
  37.                   'VGG16_faster_rcnn_final.caffemodel'),  
  38.         'myvgg': ('VGG_CNN_M_1024',  
  39.                   'VGG_CNN_M_1024_faster_rcnn_final.caffemodel'),  
  40.         'zf': ('ZF',  
  41.                   'ZF_faster_rcnn_final.caffemodel'),  
  42.         'myzf': ('ZF',  
  43.                   'zf_rpn_stage1_iter_80000.caffemodel'),  
  44. }  
  45.   
  46.   
  47. def vis_detections(im, class_name, dets, thresh=0.5):  
  48.     """Draw detected bounding boxes."""  
  49.     inds = np.where(dets[:, -1] >= thresh)[0]  
  50.     if len(inds) == 0:  
  51.         return  
  52.   
  53.     #write_file.write(array[current_image] + ' ') #add by zhipeng  
  54.     #write_file.write('face' + ' ') #add by zhipeng  
  55.     im = im[:, :, (210)]  
  56.     #fig, ax = plt.subplots(figsize=(12, 12))  
  57.     #ax.imshow(im, aspect='equal')  
  58.     for i in inds:  
  59.         bbox = dets[i, :4]  
  60.         score = dets[i, -1]  
  61.   
  62.         write_file.write(array[current_image] + ' '#add by zhipeng  
  63.         #write_file.write('face' + ' ')  
  64.         ##########   add by zhipeng for write rectange to txt   ########  
  65.         #bbox[0]:x, bbox[1]:y, bbox[2]:x+w, bbox[3]:y+h  
  66.         write_file.write( "{} {} {} {}\n".format(str(int(bbox[0])), str(int(bbox[1])),  
  67.                                                         str(int(bbox[2])-int(bbox[0])),  
  68.                                                         str(int(bbox[3])-int(bbox[1]))))  
  69.         #print "zhipeng, bbox:", bbox, "score:",score  
  70.         ##########   add by zhipeng for write rectange to txt   ########  
  71.   
  72.           
  73.   
  74. def demo(net, image_name):  
  75.     """Detect object classes in an image using pre-computed object proposals."""  
  76.   
  77.     # Load the demo image  
  78.     #im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)  
  79.     im = cv2.imread(image_name)  
  80.   
  81.     # Detect all object classes and regress object bounds  
  82.     timer = Timer()  
  83.     timer.tic()  
  84.     scores, boxes = im_detect(net, im)  
  85.     timer.toc()  
  86.     print ('Detection took {:.3f}s for '  
  87.            '{:d} object proposals').format(timer.total_time, boxes.shape[0])  
  88.   
  89.     # Visualize detections for each class  
  90.     CONF_THRESH = 0.8  
  91.     NMS_THRESH = 0.3  
  92.     for cls_ind, cls in enumerate(CLASSES[1:]):  
  93.         cls_ind += 1 # because we skipped background  
  94.         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]  
  95.         cls_scores = scores[:, cls_ind]  
  96.         dets = np.hstack((cls_boxes,  
  97.                           cls_scores[:, np.newaxis])).astype(np.float32)  
  98.         keep = nms(dets, NMS_THRESH)  
  99.         dets = dets[keep, :]  
  100.         vis_detections(im, cls, dets, thresh=CONF_THRESH)  
  101.   
  102. def parse_args():  
  103.     """Parse input arguments."""  
  104.     parser = argparse.ArgumentParser(description='Faster R-CNN demo')  
  105.     parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',  
  106.                         default=0, type=int)  
  107.     parser.add_argument('--cpu', dest='cpu_mode',  
  108.                         help='Use CPU mode (overrides --gpu)',  
  109.                         action='store_true')  
  110.     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',  
  111.                         choices=NETS.keys(), default='vgg16')  
  112.   
  113.     args = parser.parse_args()  
  114.   
  115.     return args  
  116.   
  117. if __name__ == '__main__':  
  118.     cfg.TEST.HAS_RPN = True  # Use RPN for proposals  
  119.   
  120.     args = parse_args()  
  121.   
  122.     prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],  
  123.                             'faster_rcnn_alt_opt''faster_rcnn_test.pt')  
  124.     caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',  
  125.                               NETS[args.demo_net][1])  
  126.   
  127.     if not os.path.isfile(caffemodel):  
  128.         raise IOError(('{:s} not found.\nDid you run ./data/script/'  
  129.                        'fetch_faster_rcnn_models.sh?').format(caffemodel))  
  130.   
  131.     if args.cpu_mode:  
  132.         caffe.set_mode_cpu()  
  133.     else:  
  134.         caffe.set_mode_gpu()  
  135.         caffe.set_device(args.gpu_id)  
  136.         cfg.GPU_ID = args.gpu_id  
  137.     net = caffe.Net(prototxt, caffemodel, caffe.TEST)  
  138.   
  139.     print '\n\nLoaded network {:s}'.format(caffemodel)  
  140.   
  141.     # Warmup on a dummy image  
  142.     im = 128 * np.ones((3005003), dtype=np.uint8)  
  143.     for i in xrange(2):  
  144.         _, _= im_detect(net, im)  
  145.   
  146.     '''''im_names = ['000456.jpg', '000542.jpg', '001150.jpg', 
  147.                 '001763.jpg', '004545.jpg']'''  
  148.   
  149.     ##########   add by zhipeng for write rectange to txt   ########  
  150.     #write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/tools/detections/out.txt'  
  151.     #write_file = open(write_file_name, "w")  
  152.     ##########   add by zhipeng for write rectange to txt   ########  
  153.   
  154. #    for current_file in range(1,11):      #orginal range(1, 11)  
  155.   
  156. #    print 'Processing file ' + str(current_file) + ' ...'  
  157.   
  158.     read_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos_fold/name.txt'  
  159.     write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos_fold/annotate.txt'  
  160.     write_file = open(write_file_name, "w")  
  161.   
  162.     with open(read_file_name, "r") as ins:  
  163.         array = []  
  164.         for line in ins:  
  165.             line = line[:-1]  
  166.             array.append(line)      # list of strings  
  167.   
  168.     number_of_images = len(array)  
  169.   
  170.     for current_image in range(number_of_images):  
  171.         if current_image % 100 == 0:  
  172.             print 'Processing image : ' + str(current_image)  
  173.         # load image and convert to gray  
  174.         read_img_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/pos/' + array[current_image].rstrip()  
  175.         #write_file.write(array[current_image]) #add by zhipeng  
  176.         demo(net, read_img_name)  
  177.   
  178.     write_file.close()  
  179.   
  180.     '''''for im_name in im_names: 
  181.         print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' 
  182.         print 'Demo for data/demo/{}'.format(im_name) 
  183.         write_file.write(im_name + '\n') #add by zhipeng 
  184.         demo(net, im_name)'''  
  185.   
  186.     #write_file.close()  # add by zhipeng,close file  
  187.     plt.show()  

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值