使用caffe验证LFW的6000对图片

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
import os
import caffe
import cv2
import numpy as np

import operator
import pickle

from sklearn.model_selection import KFold
from scipy import interpolate

# Ground-truth label of each LFW pair, in pairs.txt order:
# 1 = same person (3-token line), 0 = different people (4-token line).
LabelList=[]
# Splits a pairs.txt line into whitespace-separated tokens.
# Raw string: '\S' in a plain string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12).
pattern = re.compile(r'\S+')

# Root directory of the aligned LFW images
LFW_Dir='../aligned-lfw'

# Prefix joined in front of every image path in extractFeature; can stay empty
# when the image paths are already absolute (os.path.join ignores it then).
imageBasePath=''

# Directory containing the Caffe model files
caffe_model_dir ='../model_42302/model/4096-4096-1024-98.2'

# Windows used by show_pic to display an image pair (currently disabled)
# cv2.namedWindow('input_image_1',cv2.WINDOW_AUTOSIZE)
# cv2.namedWindow('input_image_2',cv2.WINDOW_AUTOSIZE)
# Model initialization
def initilize_model(Prototxt,ModelFile):
    """Load a Caffe network for inference on the CPU.

    :param Prototxt: path to the deploy .prototxt network definition
    :param ModelFile: path to the trained .caffemodel weights
    :return: the loaded caffe.Net (TEST phase), or False if either file is
             missing or not readable
    """
    print('model initilzing ...')
    # Loading only needs read access. The previous version also demanded write
    # access, which rejected read-only model files. It also returned the
    # undefined name `false`, which raised NameError.
    for path, label in ((Prototxt, 'Prototxt'), (ModelFile, 'ModelFile')):
        if not os.path.isfile(path):
            print('{} file not found: {}'.format(label, path))
            return False
        if not os.access(path, os.R_OK):
            print('{} file not readable: {}'.format(label, path))
            return False
    caffe.set_mode_cpu()
    net = caffe.Net(Prototxt, ModelFile, caffe.TEST)
    print('init model end ...')
    return net

# Cosine similarity (pure-Python version)
def cos(vector1,vector2):
    """Return the cosine of the angle between two vectors.

    Element pairs are taken with zip, so iteration stops at the shorter input.

    :param vector1: first vector (any iterable of numbers)
    :param vector2: second vector (any iterable of numbers)
    :return: dot(v1, v2) / (|v1| * |v2|), or None if either vector has zero norm
    """
    dot = sq_a = sq_b = 0.0
    for x, y in zip(vector1, vector2):
        dot, sq_a, sq_b = dot + x * y, sq_a + x ** 2, sq_b + y ** 2
    has_zero_norm = sq_a == 0.0 or sq_b == 0.0
    return None if has_zero_norm else dot / ((sq_a * sq_b) ** 0.5)
def cos_sim(vector_a, vector_b):
    """
    Compute the cosine similarity of two vectors, rescaled to the range [0, 1].

    Both inputs are flattened, so 1-D vectors and (1, N) feature rows (e.g.
    a Caffe blob for a single image) are accepted interchangeably. This uses
    plain ndarrays instead of np.mat: np.mat was removed in NumPy 2.0, and
    float() of a 1x1 matrix is deprecated.

    :param vector_a: vector a (array-like)
    :param vector_b: vector b (array-like, same number of elements as a)
    :return: sim = 0.5 + 0.5 * cos(a, b); zero-norm input is not special-cased
             (numpy division by zero)
    """
    a = np.asarray(vector_a, dtype=np.float64).ravel()
    b = np.asarray(vector_b, dtype=np.float64).ravel()
    num = float(np.dot(a, b))
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    cos = num / denom
    sim = 0.5 + 0.5 * cos
    return sim

# Euclidean distance between two vectors
def Euclidean_dist(vector_a, vector_b):
    """Return the Euclidean (L2) distance between two vectors.

    Inputs go through np.asarray with dtype float64. Plain lists and tuples
    therefore work, not just ndarrays, and unsigned-integer inputs cannot wrap
    around on subtraction. Shapes must be broadcast-compatible, e.g. two (1, N)
    feature rows.

    :param vector_a: first vector (array-like)
    :param vector_b: second vector (array-like)
    :return: sqrt(sum((a - b) ** 2)) as a numpy float64
    """
    diff = np.asarray(vector_a, dtype=np.float64) - np.asarray(vector_b, dtype=np.float64)
    dist = np.sqrt(np.sum(np.square(diff)))
    return dist
# Write distances to a file
def save_dist_tofile(distance_list, filename='dist.csv'):
    """Write one distance per line, as text, to `filename` (default 'dist.csv').

    The file is always created or truncated, even when distance_list is empty.
    The output is the format read back by read_csv_toList.

    :param distance_list: sequence of distances; each item is written as str(item)
    :param filename: output path, relative to the current directory by default
    """
    # 'with' closes the file even if converting or writing an item raises.
    with open(filename, 'w') as dist_save_file:
        if len(distance_list) > 0:
            # Stored as text, one value per line, so the data is easy to inspect.
            dist_save_file.writelines(str(line) + "\n" for line in distance_list)
    if len(distance_list) > 0:
        print('distances write to file ,name of {} , status is : OK'.format(filename))
    else:
        print('distance_list size == 0')

def save_label_tofile(LabelList, filename='labellist.pkl'):
    """Pickle the list of pair labels to a binary file (default 'labellist.pkl').

    Nothing is written when LabelList is empty.

    :param LabelList: list of labels to store
    :param filename: output path, relative to the current directory by default
    :return: None
    """
    if len(LabelList) == 0:
        # The previous version printed an empty line here, which gave no hint why nothing was saved.
        print('LabelList size == 0, nothing saved')
        return None
    # pickle keeps the Python list intact (binary format); 'with' closes the file even if dump fails.
    with open(filename, 'wb') as pickle_file:
        pickle.dump(LabelList, pickle_file)
    print('save  label to file ,OK')

# Read labels back from a file
def read_label_toList(labelfilename):
    """Load the object pickled in `labelfilename` (as written by save_label_tofile).

    NOTE(review): pickle.load can execute arbitrary code. Only use this on files
    this script produced itself, never on untrusted input.

    :param labelfilename: path of the pickle file
    :return: the unpickled object (a list of labels when written by save_label_tofile)
    """
    # 'with' closes the file even if unpickling raises.
    with open(labelfilename, 'rb') as pickle_file:
        labels = pickle.load(pickle_file)
    return labels

def read_csv_toList(csvname):
    """Read a text file containing one float per line (as written by save_dist_tofile).

    Blank lines and surrounding whitespace are ignored.

    :param csvname: path of the file
    :return: list of floats in file order, or None if the file holds no values
    :raises ValueError: if a non-blank line is not a valid float
    """
    values = []
    # 'with' closes the file even if float() raises. Ctrl-C is no longer swallowed,
    # so an interrupted read can't silently hand partial data to the evaluation.
    with open(csvname) as csv_file:
        for text in csv_file:
            text = text.strip()
            if text:
                values.append(float(text))
    print('read csv file, OK')
    if len(values) == 0:
        print('read file,but ..to list,is empty')
        return None
    return values

def show_pic(pic_1,pic_2,sleep_time=1):
    """Show two images in the OpenCV windows 'input_image_1' / 'input_image_2'.

    :param pic_1: path of the image shown in 'input_image_1'
    :param pic_2: path of the image shown in 'input_image_2'
    :param sleep_time: seconds to pause in cv2.waitKey before returning;
                       fractional values such as 0.5 are allowed
    """
    # NOTE(review): cv2.imread returns None for an unreadable path; that is not
    # checked here, so imshow will raise in that case.
    src1 = cv2.imread(pic_1)
    src2 = cv2.imread(pic_2)
    cv2.imshow('input_image_1',src1)
    cv2.imshow('input_image_2',src2)
    # cv2.waitKey expects an integer delay in milliseconds.
    cv2.waitKey(int(sleep_time*1000))

# Feature extraction
def _forward_feature(net, transformer, image_path, image_size, feature_blob):
    """Push one image through `net` and return a float64 copy of `feature_blob`'s output."""
    img = caffe.io.load_image(image_path)
    img = caffe.io.resize_image(img, (image_size, image_size))
    net.blobs['data'].data[...] = transformer.preprocess('data', img)
    net.forward()
    # np.float64(...) copies the blob. The blob buffer is overwritten by the next forward pass.
    return np.float64(net.blobs[feature_blob].data)


def extractFeature(leftImageList,rightImageList,net,feature_blob='fc9',image_size=224):
    """Compute the Euclidean feature distance for every image pair.

    Each path is joined onto imageBasePath, loaded with caffe.io.load_image,
    resized to image_size x image_size, and passed through `net` one image at
    a time.

    :param leftImageList: paths of the first image of every pair
    :param rightImageList: paths of the second image of every pair (same length)
    :param net: loaded caffe.Net whose input blob is named 'data'
    :param feature_blob: name of the blob used as the face feature (default 'fc9')
    :param image_size: side length the input is resized to (default 224)
    :return: list of distances (Euclidean_dist), one per pair, in input order
    :raises ValueError: if the two path lists have different lengths
    """
    if len(leftImageList) != len(rightImageList):
        raise ValueError('leftImageList and rightImageList must have the same length')

    # Reshape the input blob BEFORE building the transformer. The Transformer is
    # configured from the blob shape at construction time, and caffe's preprocess
    # resizes to that shape. Building it first tied preprocessing to the deploy
    # file's input size instead of image_size.
    net.blobs['data'].reshape(1, 3, image_size, image_size)
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2,0,1))     # HWC -> CHW
    transformer.set_raw_scale('data', 255)         # load_image gives [0, 1]; scale to [0, 255]
    transformer.set_channel_swap('data', (2,1,0))  # RGB -> BGR

    dist_valuelist = []
    for left_path, right_path in zip(leftImageList, rightImageList):
        feature_1 = _forward_feature(net, transformer, os.path.join(imageBasePath, left_path),
                                     image_size, feature_blob)
        feature_2 = _forward_feature(net, transformer, os.path.join(imageBasePath, right_path),
                                     image_size, feature_blob)
        dist_valuelist.append(Euclidean_dist(feature_1, feature_2))
    # Only the Euclidean distances are needed downstream (features are not kept).
    return dist_valuelist

# Full paths of the Caffe deploy definition and trained weights inside caffe_model_dir.
caffe_deploy_filepath, caffe_model_filepath = [
    os.path.join(caffe_model_dir, file_name)
    for file_name in ('facenet_deploy.prototxt', 'facenet_iter_100000.caffemodel')
]

# Model loading is disabled; enable it to run extractFeature:
#net = initilize_model(caffe_deploy_filepath,caffe_model_filepath)

def evaluate(distances, labels, nrof_folds=10, thresholds=None):
    """Evaluate pair distances with k-fold cross-validated threshold selection.

    :param distances: array of pair distances (smaller = more similar)
    :param labels: array of ground-truth labels (truthy = same person)
    :param nrof_folds: number of cross-validation folds
    :param thresholds: candidate distance thresholds. Defaults to
        np.arange(0, 30, 0.01) (3000 points), so pairs with distance >= 30 are
        never predicted "same". Pass a custom grid for differently scaled
        distances.
    :return: (tpr, fpr, accuracy) as returned by calculate_roc
    """
    if thresholds is None:
        # 3000 evenly spaced candidate thresholds in [0, 30)
        thresholds = np.arange(0, 30, 0.01)
    # calculate_roc picks the best threshold on each training split and reports
    # the accuracy of that threshold on the held-out split.
    tpr, fpr, accuracy = calculate_roc(thresholds, distances,
        labels, nrof_folds=nrof_folds)
    return tpr, fpr, accuracy

def calculate_roc(thresholds, distances, labels, nrof_folds=10):
    """K-fold ROC statistics and accuracy for a distance-threshold classifier.

    Only the first min(len(labels), len(distances)) pairs are used. Folds are
    contiguous (KFold with shuffle=False).

    :param thresholds: candidate thresholds (array-like)
    :param distances: pair distances (array-like)
    :param labels: ground-truth labels (array-like, truthy = same person)
    :param nrof_folds: number of folds
    :return: tpr, fpr, accuracy where
        - tpr, fpr: arrays of len(thresholds), the mean over folds of the
          test-split TPR/FPR at each threshold
        - accuracy: array of nrof_folds, the test-split accuracy at the
          threshold that maximized training accuracy in that fold
    """
    print('calculate_roc ---')
    thresholds = np.asarray(thresholds)
    distances = np.asarray(distances)
    labels = np.asarray(labels)
    nrof_pairs = min(len(labels), len(distances))
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds,nrof_thresholds))
    fprs = np.zeros((nrof_folds,nrof_thresholds))
    accuracy = np.zeros((nrof_folds))

    indices = np.arange(nrof_pairs)

    # k-fold cross-validation
    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
        # Slice each split once, instead of re-indexing for every threshold.
        train_dist, train_labels = distances[train_set], labels[train_set]
        test_dist, test_labels = distances[test_set], labels[test_set]

        # Pick the threshold with the highest accuracy on the training split.
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, train_dist, train_labels)
        best_threshold_index = np.argmax(acc_train)

        # ROC points on the test split for every threshold.
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, test_dist, test_labels)
        # Test accuracy at the threshold chosen on the training split.
        _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], test_dist, test_labels)

    # Averaged once after all folds. Previously this was recomputed on every iteration.
    tpr = np.mean(tprs,0)
    fpr = np.mean(fprs,0)
    return tpr, fpr, accuracy

def calculate_accuracy(threshold, dist, actual_issame):
    """Score the rule "same person iff distance < threshold" against ground truth.

    :param threshold: distance cut-off; pairs strictly below it are predicted "same"
    :param dist: array-like of pair distances
    :param actual_issame: array-like of ground-truth labels (truthy = same person)
    :return: (tpr, fpr, acc). Each ratio is 0 when its denominator is 0,
             including acc for empty input.
    """
    dist = np.asarray(dist)
    actual_issame = np.asarray(actual_issame, dtype=bool)
    predict_issame = np.less(dist, threshold)
    # true positives
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    # false positives
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    # true negatives
    tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
    # false negatives
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    # true positive rate
    tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)
    # false positive rate
    fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)
    # accuracy (guarded: an empty input previously raised ZeroDivisionError)
    acc = 0 if dist.size == 0 else float(tp+tn)/dist.size
    return tpr, fpr, acc


# First / second image path of every pair, in pairs.txt order
# (filled by loop_read_imagepath).
get_leftimage_path = list()
get_rightimage_path = list()

# Location of the LFW pairs list
pairspath = '../pairs.txt'

def _lfw_image_path(person, index):
    """Build '<LFW_Dir>/<person>/<person>_<index left-padded with 0 to 4 chars>.jpg'."""
    return os.path.join(LFW_Dir, person, '{0}_{1:0>4}.jpg'.format(person, index))


def loop_read_imagepath(pairspath):
    """Parse an LFW pairs file into the module-level path and label lists.

    - A line with 3 tokens (name, idx1, idx2) is a same-person pair, label 1.
    - A line with 4 tokens (name1, idx1, name2, idx2) is a different-person
      pair, label 0.
    - Any other line is ignored.

    Appends to the module globals get_leftimage_path, get_rightimage_path and
    LabelList, then prints the number of pairs of each kind.

    :param pairspath: path of the pairs file
    :return: None
    """
    leftNum = 0   # number of same-person pairs (3-token lines)
    rightNum = 0  # number of different-person pairs (4-token lines)
    # 'with' closes the file; previously it was never closed. Ctrl-C is no longer
    # swallowed, so an interrupted parse can't leave silently truncated lists.
    with open(pairspath) as pairs_file:
        for line in pairs_file:
            tokens = pattern.findall(line)
            if len(tokens) == 3:
                leftNum += 1
                name, idx1, idx2 = tokens
                left, right, label = _lfw_image_path(name, idx1), _lfw_image_path(name, idx2), 1
            elif len(tokens) == 4:
                rightNum += 1
                name1, idx1, name2, idx2 = tokens
                left, right, label = _lfw_image_path(name1, idx1), _lfw_image_path(name2, idx2), 0
            else:
                continue
            get_leftimage_path.append(left)
            get_rightimage_path.append(right)
            LabelList.append(label)
    print('leftNum: '+ str(leftNum))
    print('rightNum: '+ str(rightNum))


# Pipeline that produces the evaluation inputs. It needs the Caffe model (a `net`
# from initilize_model) and is left disabled:
#loop_read_imagepath(pairspath)
#save_label_tofile(LabelList)
#print(get_leftimage_path)
#distances_list = extractFeature(get_leftimage_path,get_rightimage_path,net)

# Evaluate from cached results: distances from dist.csv, labels from labellist.pkl.
# save_label_tofile(LabelList)
distances_list = read_csv_toList('dist.csv')

LabelList = read_label_toList('labellist.pkl')

# read_csv_toList returns None for an empty file. Stop with a clear message
# instead of an obscure TypeError deep inside evaluate().
if distances_list is None or len(LabelList) == 0:
    raise SystemExit('dist.csv or labellist.pkl is empty; generate them first')

tpr, fpr, accuracy = evaluate(np.array(distances_list),np.array(LabelList))

print('---tpr---: ')
print(tpr)
print('---fpr---: ')
print(fpr)
print('---accuracy---: ')
print(accuracy)
# Summary over the cross-validation folds
print('---mean accuracy---: %.4f +- %.4f' % (np.mean(accuracy), np.std(accuracy)))

参考链接:这篇文章写得不错,本文代码在其基础上做了改动。

代码中用到的模型文件(.caffemodel)和网络定义文件(.prototxt)暂不提供。

pairs.txt 文件使用 LFW 官方提供的版本。

如有疑问,请联系我,欢迎交流。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Teleger

你的支持是我前进的方向

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值