facenet代码注释

facenet识别中的embedding代码块

import tensorflow as tf
import numpy as np
import sys
import os
import copy
sys.path.append('../align/')
from MtcnnDetector import MtcnnDetector
from detector import Detector
from fcn_detector import FcnDetector
from model import P_Net,R_Net,O_Net
from utils import *
import config
import cv2
import h5py


# In[2]:


def main():
    path='../pictures/embeddings.h5'
    if os.path.exists(path):
        print('生成完了别再瞎费劲了!!!')
        return
    img_arr,class_arr=align_face()
    # tf.Graph()
    # 1、它可以通过tensorboard用图形化界面展示出来流程结构
    # 2、它可以整合一段代码为一个整体存在于一个图中
    with tf.Graph().as_default():
        with tf.Session() as sess:
            #加载模型
            load_model('../model/')
            #得到输入与输出tensors
            #images_placeholder是输入图像的占位符,后面会把images传给它
            images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
            #embeddings就是卷积网络最后输出的‘‘特征’’
            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
            #phase_train_placeholder决定了目前是否为训练阶段
            phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
            keep_probability_placeholder= tf.get_default_graph().get_tensor_by_name('keep_probability:0')

            # 前向传播计算embeddings
            # feed_dict的作用是给placeholder创建的tensor赋值的
            feed_dict = { images_placeholder: img_arr, phase_train_placeholder:False ,keep_probability_placeholder:1.0}
            embs = sess.run(embeddings, feed_dict=feed_dict)
    f=h5py.File('../pictures/embeddings.h5','w')
    class_arr=[i.encode() for i in class_arr]
    f.create_dataset('class_name',data=class_arr)
    f.create_dataset('embeddings',data=embs)
    f.close()


# In[3]:
#使用MTCNN网络在原始图片中进行检测和对齐

def align_face(path='../pictures/'):
    # 三个网络的阈值
    thresh=config.thresh
    # 设定最小脸大小
    min_face_size=config.min_face
    #PNet图片缩小倍数
    stride=config.stride
    # 测试选择的网络
    test_mode=config.test_mode
    #设detector默认值None
    detectors=[None,None,None]
    # 模型放置位置
    model_path=['../align/model/PNet/','../align/model/RNet/','../align/model/ONet']
    batch_size=config.batches
    #对于PNet,其FcnDetector用于识别单张图片
    PNet=FcnDetector(P_Net,model_path[0])
    detectors[0]=PNet

    # 对于RNet,ONet,其Detector用于识别多张图片
    if test_mode in ["RNet", "ONet"]:
        RNet = Detector(R_Net, 24, batch_size[1], model_path[1])
        detectors[1] = RNet


    if test_mode == "ONet":
        ONet = Detector(O_Net, 48, batch_size[2], model_path[2])
        detectors[2] = ONet

    mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,
                                   stride=stride, threshold=thresh)
   
    #选用图片
    img_paths=os.listdir(path)
    # 获取图片类别和路径
    class_names=[a.split('.')[0] for a in img_paths]
    img_paths=[os.path.join(path,p) for p in img_paths]
    scaled_arr=[]
    class_names_arr=[]
    
    for image_path,class_name in zip(img_paths,class_names):
        #cv2.imread读入图片
        img = cv2.imread(image_path)
          #cv2.imshow()函数可以在窗口中显示图像,参数分别为窗口名字和图像
          #cv2.waitkey()为键盘绑定函数,参数表示等待毫秒数,0表示无期限等待键盘输入
#         cv2.imshow('',img)
#         cv2.waitKey(0)
        try:
            boxes_c,_=mtcnn_detector.detect(img)
        except:
            print('识别不出图像:{}'.format(image_path))
            continue
        #人脸框数量
        num_box=boxes_c.shape[0]
        if num_box>0:
            det=boxes_c[:,:4]
            det_arr=[]
            #原图片大小
            img_size=np.asarray(img.shape)[:2]
            if num_box>1:
                #如果保留一张脸,但存在多张,只保留置信度最大的
                score=boxes_c[:,4]
                index=np.argmax(score)
                det_arr.append(det[index,:])
            else:
                # 只有一个人脸框的话,那就没得选了
                det_arr.append(np.squeeze(det))
            for i,det in enumerate(det_arr):
                det=np.squeeze(det)
                #边界框周围的裁剪边缘,获得左上角和右下角的坐标
                bb=[int(max(det[0],0)), int(max(det[1],0)), int(min(det[2],img_size[1])), int(min(det[3],img_size[0]))]
                #截取图片
                cropped = img[bb[1]:bb[3],bb[0]:bb[2],:]
                #图片截成160*160大小以便作为facenet的输入,并归一化处理
                scaled =cv2.resize(cropped,(160, 160),interpolation=cv2.INTER_LINEAR)-127.5/128.0
                scaled_arr.append(scaled)
                class_names_arr.append(class_name)
    
        else:
            print('图像不能对齐 "%s"' % image_path)
    scaled_arr=np.asarray(scaled_arr)
    class_names_arr=np.asarray(class_names_arr)
    return scaled_arr,class_names_arr


# In[4]:


def load_model(model_dir,input_map=None):
    '''重载模型'''
    
    ckpt = tf.train.get_checkpoint_state(model_dir)                         
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   
    saver.restore(tf.get_default_session(), ckpt.model_checkpoint_path)


# In[ ]:


if __name__=='__main__':
    main()


test部分代码块

import tensorflow as tf
import numpy as np
import sys
import os
import copy
from embeddings import load_model
sys.path.append('../align/')
from MtcnnDetector import MtcnnDetector
from detector import Detector
from fcn_detector import FcnDetector
from model import P_Net,R_Net,O_Net
import config
import cv2
import h5py
# 识别人脸阈值
THRED=0.002


# In[2]:


def main():
    #读取对比图片的embeddings和class_name
    f=h5py.File('../pictures/embeddings.h5','r')
    class_arr=f['class_name'][:]
    class_arr=[k.decode() for k in class_arr]
    emb_arr=f['embeddings'][:]
    # 看是读取摄像头数据还是图片
    # 读取摄像头,0为摄像头索引,当有多个摄像头时,从0开始编号
    cap=cv2.VideoCapture(0)
    # cv2.VideoWriter()指定写入视频帧编码格式,fourcc(*'XVID')设置视频的解码器
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    path='../output'
    if not os.path.exists(path):
        os.mkdir(path)
    out = cv2.VideoWriter(path+'/out.mp4' ,fourcc,10,(640,480))
    #加载MTCNN模型
    mtcnn_detector=load_align()
    with tf.Graph().as_default():
        with tf.Session() as sess:
            load_model('../model/')
            images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
            phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
            keep_probability_placeholder= tf.get_default_graph().get_tensor_by_name('keep_probability:0')
            while True:
                    #t=GetTickCount()求一段代码的运行时间,单位毫秒ms,返回的时间。
                    t1=cv2.getTickCount()
                    # 捕获一帧图像
                    # 第一个参数ret 为True 或者False,代表有没有读取到图片
                    # 第二个参数frame表示截取到一帧的图片
                    ret,frame = cap.read()
                    if ret == True:
                        img,scaled_arr,bb_arr=align_face(frame,mtcnn_detector)
                        if scaled_arr is not None:
                            feed_dict = { images_placeholder: scaled_arr, phase_train_placeholder:False ,keep_probability_placeholder:1.0}
                            embs = sess.run(embeddings, feed_dict=feed_dict)
                            face_num=embs.shape[0]
                            face_class=['Others']*face_num
                            for i in range(face_num):
                                diff=np.mean(np.square(embs[i]-emb_arr),axis=1)
                                min_diff=min(diff)
                                print(min_diff)
                                #小于阈值则归为一类,同一个人
                                if min_diff<THRED:
                                    index=np.argmin(diff)
                                    face_class[i]=class_arr[index]
                            #t=GetTickCount()求一段代码的运行时间,单位毫秒ms,返回的时间。
                            t2=cv2.getTickCount()
                            # getTickFrequency()此函数返回每秒内时钟的周期数
                            t=(t2-t1)/cv2.getTickFrequency()
                            #fps是指画面每秒传输帧数
                            fps=1.0/t
                            for i in range(face_num):
                                bbox=bb_arr[i]
                                # cv2.putText()这个函数是opencv里面向图像上添加文本内容的函数,各参数依次是:图片,添加的文字,左上角坐标,字体,字体大小,颜色,字体粗细
                                cv2.putText(img, '{}'.format(face_class[i]), 
                                        (bbox[0], bbox[1] - 2), 
                                        cv2.FONT_HERSHEY_SIMPLEX,
                                        0.5,(0, 255, 0), 2)
                            
                                #画fps值
                                cv2.putText(img, '{:.4f}'.format(t) + " " + '{:.3f}'.format(fps), (10, 20),
                                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 2)
                        else:
                            img=frame
                        #保存图像
                        a = out.write(img)
                        # cv2.imshow()函数可以在窗口中显示图像,参数分别为窗口名字和图像
                        cv2.imshow("result", img)

                        # waitKey(1)的数字代表等待按键输入之前的无效时间,单位为毫秒,在这个时间段内按键 ‘q’ 不会被记录,在这之后按键才会被记录,并在下一次进入if语段时起作用。也即经过无效时间以后,检测在上一次显示图像的时间段内按键 ‘q’ 有没有被按下,若无则跳出if语句段,捕获并显示下一帧图像。
                        #若此参数置零,则代表在捕获并显示了一帧图像之后,程序将停留在if语句段中一直等待 ‘q’ 被键入。
                        #cv2.waitKey(1)与 0xFF(11111111)相与是因为cv2.waitKey(1)的返回值不止8位,但是只有后8位实际有效,为避免产干扰,通过 ‘与’ 操作将其余位置0。
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break
                    else:
                        break
            #关闭读取操作,写入操作和相应的显示窗口。
            cap.release()
            out.release()
            cv2.destroyAllWindows()


# In[3]:


def load_align():
    thresh=config.thresh
    min_face_size=config.min_face
    stride=config.stride
    test_mode=config.test_mode
    detectors=[None,None,None]
    # 模型放置位置
    model_path=['../align/model/PNet/','../align/model/RNet/','../align/model/ONet']
    batch_size=config.batches
    PNet=FcnDetector(P_Net,model_path[0])
    detectors[0]=PNet


    if test_mode in ["RNet", "ONet"]:
        RNet = Detector(R_Net, 24, batch_size[1], model_path[1])
        detectors[1] = RNet


    if test_mode == "ONet":
        ONet = Detector(O_Net, 48, batch_size[2], model_path[2])
        detectors[2] = ONet

    mtcnn_detector = MtcnnDetector(detectors=detectors, min_face_size=min_face_size,
                                   stride=stride, threshold=thresh)
    return mtcnn_detector


# In[4]:


def align_face(img,mtcnn_detector):

    try:
        boxes_c,_=mtcnn_detector.detect(img)
    except:
        print('找不到脸')
        return None,None,None
    #人脸框数量
    num_box=boxes_c.shape[0]
    bb_arr=[]
    scaled_arr=[]
    if num_box>0:
        det=boxes_c[:,:4]
        det_arr=[]
        img_size=np.asarray(img.shape)[:2]
        for i in range(num_box):
            det_arr.append(np.squeeze(det[i]))
            
        for i,det in enumerate(det_arr):
            det=np.squeeze(det)
            bb=[int(max(det[0],0)), int(max(det[1],0)), int(min(det[2],img_size[1])), int(min(det[3],img_size[0]))]
            cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),(0,255,0),2)
            bb_arr.append([bb[0],bb[1]])
            cropped = img[bb[1]:bb[3],bb[0]:bb[2],:]
            scaled =cv2.resize(cropped,(160,160),interpolation=cv2.INTER_LINEAR)
            scaled=cv2.cvtColor(scaled,cv2.COLOR_BGR2RGB)-127.5/128.0
            scaled_arr.append(scaled)
        scaled_arr=np.array(scaled_arr)
        return img,scaled_arr,bb_arr
    else:
        print('找不到脸 ')
        return None,None,None
      


# In[5]:


if __name__=='__main__':
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值