pytorch-FaceNet(添加face match function)

FaceNet class(见 bubbliiiing-github

class Facenet(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测要修改model_path,指向logs文件夹下的权值文件
        #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
        #   验证集损失较低不代表准确度较高,仅代表该权值在验证集上泛化性能较好。
        #--------------------------------------------------------------------------#
        "model_path"    : "model_data/facenet_inception_resnetv1.pth",
        #--------------------------------------------------------------------------#
        #   输入图片的大小。
        #--------------------------------------------------------------------------#
        "input_shape"   : [160, 160, 3],
        #--------------------------------------------------------------------------#
        #   所使用到的主干特征提取网络
        #--------------------------------------------------------------------------#
        "backbone"      : "inception_resnetv1",
        #-------------------------------------------#
        #   是否进行不失真的resize
        #-------------------------------------------#
        "letterbox_image"   : True,
        #-------------------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------------------#
        "cuda"              : True,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化Facenet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.generate()
        
    def generate(self):
        #---------------------------------------------------#
        #   载入模型与权值
        #---------------------------------------------------#
        print('Loading weights into state dict...')
        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net    = facenet(backbone=self.backbone, mode="predict").eval()
        self.net.load_state_dict(torch.load(self.model_path, map_location=device), strict=False)
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()
    
    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image_1, image_2,mode):
        #---------------------------------------------------#
        #   图片预处理,归一化
        #---------------------------------------------------#
        with torch.no_grad():
            image_1 = resize_image(image_1, [self.input_shape[1], self.input_shape[0]], letterbox_image=self.letterbox_image)
            image_2 = resize_image(image_2, [self.input_shape[1], self.input_shape[0]], letterbox_image=self.letterbox_image)
            
            photo_1 = torch.from_numpy(np.expand_dims(np.transpose(preprocess_input(np.array(image_1, np.float32)), (2, 0, 1)), 0))
            photo_2 = torch.from_numpy(np.expand_dims(np.transpose(preprocess_input(np.array(image_2, np.float32)), (2, 0, 1)), 0))
            
            if self.cuda:
                photo_1 = photo_1.cuda()
                photo_2 = photo_2.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            output1 = self.net(photo_1).cpu().numpy()
            output2 = self.net(photo_2).cpu().numpy()
            
            #---------------------------------------------------#
            #   计算二者之间的距离
            #---------------------------------------------------#
            l1 = np.linalg.norm(output1 - output2, axis=1)

        if mode == True:
            plt.subplot(1, 2, 1)
            plt.imshow(np.array(image_1))

            plt.subplot(1, 2, 2)
            plt.imshow(np.array(image_2))
            plt.text(-12, -12, 'Distance:%.3f' % l1, ha='center', va= 'bottom',fontsize=11)
            plt.show()


        return l1

添加 Face match 功能,一对多的匹配。

import os

from PIL import Image

from facenet import Facenet


def Face_match(face_path, face_database, mode):
    """
    :param face_path: 放需要检测的人脸
    :param face_database: 匹配的人脸库
    :param mode: True/False 是否需要展示匹配的图片(同时显示distance)
    :return: 
    """

    model = Facenet ()

    face_unkonw_path = face_path
    face_known_path = face_database

    face_unknown = []
    face_database = []

    for file in os.listdir (face_unkonw_path):
        face_unknown.append (file)
    for file in os.listdir (face_known_path):
        face_database.append (file)

    print ('-' * 60)
    print ("Unknowm faces: {} ".format (face_unknown))
    print ("Face database: {} ".format (face_database))
    print ('-' * 60)

    for unkonw_face in face_unknown:

        face_distance = []
        face_path_1 = os.path.join (face_unkonw_path, unkonw_face)
        face_img_1 = Image.open (face_path_1)

        for face in face_database:
            face_path_2 = os.path.join (face_known_path, face)
            face_img_2 = Image.open (face_path_2)

            probability = model.detect_image (face_img_1, face_img_2, mode)
            face_distance.append (probability[0])

        index = face_distance.index (min (face_distance))
        print ("The matched face for {} is: {} ".format (unkonw_face, face_database[index]))


if __name__ == "__main__":
    face_unkonw_path = r'FaceImages'
    face_known_path = r'FaceDatabase'
    mode = False

    Face_match (face_unkonw_path, face_known_path, mode)

在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值