mxnet加载模型并进行前向推断

3 篇文章 0 订阅
1 篇文章 0 订阅

mxnet是由华人为主的团队(陈天奇,王乃岩)开发的深度学习架构;主要开发语言是python,相比TensorFlow,其最大的特点是接口友好。

训练得到新的mxnet模型(.params是二进制参数文件,.json是文本网络结构文件)之后,拿模型来进行预测也是工程中重要的工作。这一过程核心代码为:

    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    print(sym)
    # print(arg_params)
    # print(aux_params)

    # 提取中间某层输出帖子特征层作为输出
    all_layers = sym.get_internals()
    print(all_layers)
    sym = all_layers['fc1_output']

    # 重建模型
    model = mx.mod.Module(symbol=sym, label_names=None)
    model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])
    model.set_params(arg_params, aux_params)

 其中,mx.model.load_checkpoint函数是加载模型,两个参数分别是模型文件名的前缀和epoch数目;中间三行以get_internals为核心,这三行的作用是提取某一层为特征输出层,尤其是在人脸识别领域,真正的特征层是交叉熵层前面的高维向量层;后面三行,是重建模型和将模型结构和参数绑定的过程。

一个基于已训练好的模型的进行前向推断的完整Python代码如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/30 19:27
# @Author  : wangmeng
# @File    : inference.py
"""
mxnet前向过程
"""
import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtuple

prefix = "model-IR-m1-softmax"
epoch = 481
# img1_path = "./data/039.jpg"
# img2_path = "./data/039.jpg"


def str_expansion(sstring):
    """
    将“数字型”字符扩展
    :param string: 
    :return: 
    """
    ssize  = len(sstring)
    if ssize == 1:
        sstring = "00" + sstring
    elif ssize == 2:
        sstring = "0" + sstring
    return sstring


def str_expansion_lfw(sstring):

    """
    :param sstring: 将“数字型”字符扩展,针对lfw数据集
    :return: 
    """
    ssize = len(sstring)
    if ssize == 1:
        sstring = "000" + sstring
    elif ssize == 2:
        sstring = "00" + sstring
    elif ssize == 3:
        sstring = "0" + sstring
    return sstring


def cos_similarity(x, y):
    """
    计算两个向量x, y 的余弦相似度,就是余弦
    :param x: 
    :param y: 
    :return: 
    """
    length = len(x)

    x_squre = 0
    y_squre = 0
    xy_inner_product = 0

    for i in range(length):
        x_squre += x[i] * x[i]
        y_squre += y[i] * y[i]
        xy_inner_product += x[i] * y[i]
    print(x_squre)
    print(y_squre)
    return xy_inner_product /(math.sqrt(x_squre) * math.sqrt(y_squre))


def single_input(path):
    """
    给出图片路径,生成mxnet预测所需固定格式的输入
    :param path: 图片路径
    :return: mxnet所需格式的数据
    """
    img = cv2.imread(path)
    # mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (112, 112))

    # 重塑数组的形态,从(图片高度, 图片宽度, 3)重塑为(3, 图片高度, 图片宽度)
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)

    # 添加一个第四维度并构建NDArray
    img = img[np.newaxis, :]
    array = mx.nd.array(img)
    # print("单张图片输入尺寸:", array.shape)
    return array


if __name__ == "__main__":
    time_start = time.time()

    # verication_folder = "E:/face_detection/verification/IRimg_verification"
    verication_folder = "E:/face_detection/LFWtest/lfw_112_112"
    pair_file = "pairs.txt"
    # file = os.path.join(verication_folder, pair_file)
    file = "E:\\face_detection\\LFWtest\\pairs.txt"

    pair_list = []

    with open(file, 'r') as f:
        pair_list = f.readlines()

    time0 = time.time()
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    print(sym)
    # print(arg_params)
    # print(aux_params)

    # 提取中间某层输出帖子特征层作为输出
    all_layers = sym.get_internals()
    print(all_layers)
    sym = all_layers['fc1_output']

    # 重建模型
    model = mx.mod.Module(symbol=sym, label_names=None)
    model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])
    model.set_params(arg_params, aux_params)

    time1 = time.time()

    time_load = time1 - time0
    # print("模型加载和重建时间:{0}".format(time1 - time0))

    Batch = namedtuple("batch", ['data'])

    threshold = 0.6
    TP = 0
    TN = 0
    NUM_IR = 3280 / 2
    NUM_lfw = 6000 / 2

    time_frame = 0

    for item in pair_list:
        line = item.rstrip().split("\t")
        print(line)

        # 属于同一个人的图片对验证
        if len(line) == 3:
            time2 = time.time()
            folder = line[0]
            # img1 = str_expansion(str(int(line[1])-1)) + ".jpg"
            # img2 = str_expansion(str(int(line[2])-1)) + ".jpg"

            img1 = folder + '_' + str_expansion_lfw(line[1]) + ".jpg"
            img2 = folder + '_' + str_expansion_lfw(line[2]) + ".jpg"

            img1_path = os.path.join(verication_folder, folder, img1)
            img2_path = os.path.join(verication_folder, folder, img2)

            array1 = single_input(img1_path)
            array2 = single_input(img2_path)

            model.forward(Batch([array1]))
            vector1 = model.get_outputs()[0].asnumpy()
            vector1 = np.squeeze(vector1)

            model.forward(Batch([array2]))
            vector2 = model.get_outputs()[0].asnumpy()
            vector2 = np.squeeze(vector2)

            similarity = cos_similarity(vector1, vector2)
            time3 = time.time()
            time_frame = time3 - time2 + time_frame
            print(similarity, "\n")
            if similarity >= threshold:
                TP += 1

        # 属于不同的人的图片对验证
        if len(line) == 4:
            time4 = time.time()
            folder1 = line[0]
            # img1 = str_expansion(str(int(line[1])-1)) + ".jpg"
            img1 = folder1 + "_" + str_expansion_lfw(line[1]) + ".jpg"
            folder2 = line[2]
            # img2 = str_expansion(str(int(line[3])-1)) + ".jpg"
            img2 = folder2 + "_" + str_expansion_lfw(line[3]) + ".jpg"

            img1_path = os.path.join(verication_folder, folder1, img1)
            img2_path = os.path.join(verication_folder, folder2, img2)

            array1 = single_input(img1_path)
            array2 = single_input(img2_path)

            model.forward(Batch([array1]))
            vector1 = model.get_outputs()[0].asnumpy()
            vector1 = np.squeeze(vector1)

            model.forward(Batch([array2]))
            vector2 = model.get_outputs()[0].asnumpy()
            vector2 = np.squeeze(vector2)

            similarity = cos_similarity(vector1, vector2)
            time5 = time.time()
            time_frame = time5 - time4 + time_frame
            print(similarity, "\n")
            if similarity < threshold:
                TN += 1

    print("检真正确率:{0:.4f}".format(TP / NUM_lfw))
    print("拒假正确率:{0:.4f}".format(TN / NUM_lfw))
    print("模型加载时间: {0:.3f}s".format(time_load))
    print("检测一帧平均时间: {0:.3f}s".format(time_frame / (NUM_lfw*2)))

    time_end = time.time()

    print("程序运行时间: {0:.2f}min".format((time_end-time_start)/60))

    # exit()

 

  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值