MobileFaceNet:从训练到Tflite

本教程是基于MobileFaceNet_TF项目之后的

1. 得到了ckpt文件后,如何转为tflite?

  1. 1 修改MobileFaceNet.py中batch_norm_params中的is_training&trainable都为false
  2. 2 自己写一个模型冻结脚本,代码如下:
from losses.face_losses import insightface_loss, cosineface_loss, combine_loss
from utils.data_process import parse_function, load_data
from nets.MobileFaceNet import inference
# from losses.face_losses import cos_loss
from verification import evaluate
from scipy.optimize import brentq
from utils.common import train
from scipy import interpolate
from datetime import datetime
from sklearn import metrics
import tensorflow as tf
import numpy as np
import argparse
import time
import os
from tensorflow.python.framework import graph_util

slim = tf.contrib.slim


def get_parser():
      parser = argparse.ArgumentParser(description='parameters to train net')
      parser.add_argument('--pretrained_model', type=str, default="/home/ubuntu/Project/MobileFaceNet_TF-master/arch/pretrained_model",
                          help='Load a pretrained model before training starts.')
      parser.add_argument('--output_file', type=str,default="/home/ubuntu/Project/MobileFaceNet_TF-master/arch/pretrained_model/saved_model.pb", help='Filename for the exported graphdef protobuf (.pb)')

      args = parser.parse_args()
      return args


def freeze_graph_def(sess, input_graph_def, output_node_names):
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in xrange(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr: del node.attr['use_locking']

    # Get the list of important nodes
    whitelist_names = []
    for node in input_graph_def.node:
        if (node.name.startswith('MobileFaceNet') or node.name.startswith('embeddings')):
            whitelist_names.append(node.name)

    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def


if __name__ == '__main__':
    with tf.Graph().as_default():
        args = get_parser()

        # define placeholder
        inputs = tf.placeholder(name='img_inputs', shape=[None, 112, 112, 3], dtype=tf.float32)
        labels = tf.placeholder(name='img_labels', shape=[None, ], dtype=tf.int64)
        phase_train_placeholder = tf.placeholder_with_default(tf.constant(False, dtype=tf.bool), shape=None,
                                                              name='phase_train')

        # pretrained model path
        pretrained_model = None
        if args.pretrained_model:
            pretrained_model = os.path.expanduser(args.pretrained_model)
            print('Pre-trained model: %s' % pretrained_model)

        # identity the input, for inference
        inputs = tf.identity(inputs, 'input')

        prelogits, net_points = inference(inputs, bottleneck_layer_size=192, phase_train=False, weight_decay=5e-5)

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        sess = tf.Session()

        # saver to load pretrained model or save model
        # MobileFaceNet_vars = [v for v in tf.trainable_variables() if v.name.startswith('MobileFaceNet')]
        saver = tf.train.Saver(tf.trainable_variables())

        # init all variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # load pretrained model
        if pretrained_model:
            print('Restoring pretrained model: %s' % pretrained_model)
            ckpt = tf.train.get_checkpoint_state(pretrained_model)
            print(ckpt)
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Retrieve the protobuf graph definition and fix the batch norm nodes
        input_graph_def = sess.graph.as_graph_def()

        # Freeze the graph def
        output_graph_def = freeze_graph_def(sess, input_graph_def, 'embeddings')

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(args.output_file, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph: %s" % (len(output_graph_def.node), args.output_file))
  • 在get_parser中需要定义ckpt所在文件夹和将要输出文件的全路径

  • 在inference(inputs, bottleneck_layer_size=128, phase_train=False, weight_decay=5e-5) 中

    • inputs:指定输出节点名称,在train_nets.py中可以查看_
    • bottleneck_layer_size要设置为训练模型时输出维度(猜测),可在train_nets.py中embedding_size查看
    • phase_train实际上就是将1.1中batch_norm_params的is_training设为了false
    • weight_decay:在train_nets.py也可查看
    • 在train_nets.py中查看:找到inference函数看看传入的值是啥就知道了

运行脚本后,就会得到.pb文件

2. pb转tflite

  1. 环境tensorflow-gpu=1.15(训练和转换环境都是),所以使用tflite_convert命令行工具。该工具嵌入在tensorflow1.x>中,可以尝试在命令行中使用“tflite_convert”命令,看是否存在该命令,一般来说如果你正常。

    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md

  2. 如何转换,在命令行中代码如下:

    tflite_convert 
    --output_file /home/ubuntu/Project/MobileFaceNet_TF-master/arch/pretrained_model/saved_model.tflite 
    --graph_def_file /home/ubuntu/Project/MobileFaceNet_TF-master/arch/pretrained_model/saved_model.pb 
    --input_arrays input 
    --input_shapes 1,112,112,3 
    --output_arrays embeddings 
    --output_format TFLITE
    
    

–output_file:定义了输出tflite的全路径

–graph_def:定义了pb文件路径

–input_arrays:指明输入节点

–input_shapes:指明输入shape

output_arrays:指明输出节点

–output_format:指明转换格式

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值