人脸识别（InsightFace + MXNet）：读取模型并提取特征的代码

# -*- coding: UTF-8 -*-
import os
import numpy as np
import cPickle
from sklearn.metrics import roc_curve, auc
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import timeit
import sklearn
import cv2
import sys
import glob
import struct
#
# from menpo.visualize import print_progress
# from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
# from prettytable import PrettyTable
# from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
import mxnet as mx
from tqdm import tqdm

class Embedding:
    """Wrap an MXNet checkpoint as a single-image embedding extractor.

    The graph is truncated at the ``fc1_output`` internal layer and bound for
    inference with a batch size of 1; ``get`` returns that layer's output.
    """

    def __init__(self, prefix, epoch, ctx_id=0, image_size=(112, 112)):
        """Load the checkpoint ``prefix``/``epoch`` and bind it for inference.

        Args:
            prefix: checkpoint path prefix for ``mx.model.load_checkpoint``.
            epoch: checkpoint epoch number.
            ctx_id: GPU index to run on; a negative value selects the CPU.
            image_size: (height, width) of the network input.
        """
        print('loading', prefix, epoch)
        ctx = mx.cpu() if ctx_id < 0 else mx.gpu(ctx_id)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        # Cut the graph at 'fc1_output' so forward() yields that layer's
        # activations, which this script uses as the face feature.
        sym = sym.get_internals()['fc1_output']
        self.image_size = tuple(image_size)
        height, width = self.image_size
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        model.bind(for_training=False,
                   data_shapes=[('data', (1, 3, height, width))])
        model.set_params(arg_params, aux_params)
        self.model = model

    def get(self, rimg):
        """Return the ``fc1_output`` activations for one BGR image.

        Args:
            rimg: HxWx3 BGR image (as returned by ``cv2.imread``) whose
                height and width equal ``self.image_size``.

        Returns:
            numpy array of the module output; its first axis is the batch
            axis, of size 1.

        Raises:
            ValueError: if ``rimg`` is None (for example, a failed imread).
        """
        if rimg is None:
            raise ValueError('image is None (failed to read?)')
        img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW, RGB
        height, width = self.image_size
        # Use the same (1, 3, H, W) layout the module was bound with in
        # __init__, so non-square sizes work too.
        input_blob = np.zeros((1, 3, height, width), dtype=np.uint8)
        input_blob[0] = img
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data,))
        self.model.forward(db, is_train=False)
        return self.model.get_outputs()[0].asnumpy()
        
def get_image_feature(fw, img_path, img_list_path, model_path, gpu_idd, mbatch,
                      max_images=5000):
    """Embed the images named in a list file and write their features to ``fw``.

    Args:
        fw: writable text file; receives one line of space-separated floats
            per image.
        img_path: directory that list entries are relative to.
        img_list_path: text file whose first whitespace-separated field on
            each line is an image path; blank lines are skipped.
        model_path: checkpoint prefix passed to ``Embedding``.
        gpu_idd: GPU index passed to ``Embedding``.
        mbatch: checkpoint epoch passed to ``Embedding``.
        max_images: only the first ``max_images`` lines of the list are
            used (``None`` = all).

    Returns:
        float32 numpy array stacking the per-image ``Embedding.get`` outputs.

    Raises:
        IOError: if a listed image cannot be read.
    """
    embedding = Embedding(model_path, mbatch, gpu_idd)
    with open(img_list_path) as img_list:
        files = img_list.readlines()[0:max_images]
    img_feats = []
    for each_line in files:
        fields = each_line.strip().split()
        if not fields:
            continue
        img_name = os.path.join(img_path, fields[0])
        img = cv2.imread(img_name)
        if img is None:
            raise IOError('cannot read image: %s' % img_name)
        # Resize if EITHER dimension differs from the 112x112 network input.
        if img.shape[0] != 112 or img.shape[1] != 112:
            img = cv2.resize(img, (112, 112), interpolation=cv2.INTER_CUBIC)
        # Run the forward pass once and reuse the result for both outputs.
        feat = embedding.get(img)
        img_feats.append(feat)
        # Save the image feature: flatten so the line holds plain floats
        # rather than a bracketed numpy array repr.
        for v in np.asarray(feat, dtype=np.float32).flatten():
            fw.write(str(v) + " ")
        fw.write("\n")

    return np.array(img_feats).astype(np.float32)
          


def get(model_path="./model_best", gpu_id=9,
        path="/nfs-data/shiyy/faces_glint/imgs", total=100):
    """Extract features for images under ``path/<label>/<imgname>``.

    Three files are written in the working directory, one entry per feature row:
      * all_asian_feature.txt  - a line of space-separated floats
      * all_asian_feature.data - a struct "i" record count, then each row
                                 packed as struct "512f"
      * all_asian_label.txt    - "<label>/<imgname> <label>"

    Extraction stops once ``total`` rows have been written. The binary header
    is then rewritten with the actual number of rows, so it also stays correct
    when fewer than ``total`` images exist. Directories and files are visited
    in sorted order so the selected subset is reproducible. Images that cv2
    cannot read are skipped.

    Returns:
        Number of feature rows written.
    """
    embedding = Embedding(model_path, 0, gpu_id)
    count = 0
    with open("all_asian_feature.txt", "w") as f, \
            open('all_asian_feature.data', 'wb') as fw_bin, \
            open("all_asian_label.txt", "w") as f_lab:
        # Placeholder header; patched with the real count after the loop.
        fw_bin.write(struct.pack("i", total))

        for label in tqdm(sorted(os.listdir(path))):
            if count >= total:
                break
            label_dir = os.path.join(path, label)
            if not os.path.isdir(label_dir):
                continue
            for imgname in sorted(os.listdir(label_dir)):  # image file names
                if count >= total:
                    break
                imgpath = os.path.join(label_dir, imgname)
                img = cv2.imread(imgpath)
                if img is None:
                    print("skip unreadable image: " + imgpath)
                    continue

                feat = np.array(embedding.get(img)).astype(np.float32)
                # One row per image for a single image; a (batch, 512) output
                # is handled the same way, capped so we never exceed `total`.
                rows = feat[:total - count]
                for face_feature in rows:
                    fw_bin.write(struct.pack("512f", *face_feature))
                    # e.g. "2/30.jpg 2"
                    f_lab.write(str(label) + "/" + imgname + " " + str(label) + "\n")
                    f.write("".join(str(v) + " " for v in face_feature) + "\n")
                count += len(rows)

        # Make the header agree with the number of rows actually stored.
        fw_bin.seek(0)
        fw_bin.write(struct.pack("i", count))

    print("****************************")
    print(count)
    print("提取特征结束")
    return count


# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    get()
         # Batch variant: read samples with a data loader and save features per batch.
def extract_rec_features(t_model, args, config, data_shape, path_imgrec, mean,
                         fw, fw_bin, f_lab):
    """Run ``t_model`` over a RecordIO dataset and save one entry per sample.

    For every sample in every batch this writes:
      * ``fw``: a line of space-separated floats,
      * ``fw_bin``: the row packed as struct "512f",
      * ``f_lab``: the int32-cast value of the batch's first label array.

    ``shuffle=True`` is passed to the iterator, so output order can differ
    between runs (features and labels stay aligned with each other).

    NOTE(review): ``FaceImageIter`` is not defined or imported in this file;
    it must come from the surrounding training code. Confirm its import
    before calling. Also check whether it pads the final batch
    (``batch.pad``). If it does, the padded rows are written here too.
    """
    data2 = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=config.data_rand_mirror,  # true in the author's config
        rand_resize=False,  # author-added option to vary image resolution
        mean=mean,
        cutoff=config.data_cutoff,  # 0 in the author's config
        color_jittering=config.data_color,  # 0 in the author's config
        images_filter=config.data_images_filter,  # 0 in the author's config
    )
    data2.reset()
    for batch2 in iter(data2):
        t_model.forward(batch2, is_train=False)
        # Convert outputs and labels to numpy once per batch, not once per row.
        feats = t_model.get_outputs()[0].asnumpy().astype(np.float32)  # (batch, 512)
        labels = np.array(batch2.label[0].asnumpy()).astype(np.int32)
        for i, face_feature in enumerate(feats):
            # Float text features: one line per sample.
            fw.write("".join(str(v) + " " for v in face_feature) + "\n")
            # Binary features.
            fw_bin.write(struct.pack("512f", *face_feature))
            f_lab.write(str(labels[i]) + "\n")
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值