Model compression and real-time inference at a scale of a hundred million records


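This post records a simple example of compressing a model and using it for real-time inference on a daily volume of hundreds of millions of records.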

1. Knowledge distillation

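Distillation is covered in other articles in this column, so it is not described in detail here. The reason for distilling is to use the existing large model to improve the quality of the small model we want to end up with.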
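For readers who have not seen those articles, the core of the widely used soft-target distillation loss (teacher and student logits softened by a temperature, compared with KL divergence, mixed with the ordinary hard-label loss) can be sketched as below. This is a generic textbook formulation, not necessarily the loss used in this project; the temperature and alpha values are illustrative assumptions.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() numerically stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: List[float],
                      teacher_logits: List[float],
                      label: int,
                      temperature: float = 2.0,
                      alpha: float = 0.5) -> float:
    """alpha * soft-target loss + (1 - alpha) * hard-label cross-entropy.

    The soft-target term is KL(teacher || student) on the softened
    distributions, scaled by T^2 so its magnitude stays comparable to the
    hard-label term when the temperature changes.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft_loss = kl * temperature ** 2

    # Ordinary cross-entropy against the ground-truth label at T = 1.
    hard_loss = -math.log(softmax(student_logits)[label])
    return alpha * soft_loss + (1 - alpha) * hard_loss


if __name__ == "__main__":
    teacher = [4.0, 1.0, 0.5]
    close_student = [3.5, 1.2, 0.4]   # imitates the teacher
    far_student = [0.2, 3.0, 1.0]     # disagrees with the teacher
    print(distillation_loss(close_student, teacher, label=0))
    print(distillation_loss(far_student, teacher, label=0))
```

A student whose logits track the teacher's gets a lower loss, which is the signal that lets the large model pull the small one toward its behaviour.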

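Note that if the teacher and student are trained jointly, the nodes of the teacher and student networks must be named distinguishably (for example, every student variable name contains "student").

During training, save only the student network's variables; the inference model likewise restores only the student variables.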

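import tensorflow as tf  # TensorFlow 1.x API
# session and save_path come from the surrounding training / inference code

# --- training: save only the student network's variables ---
session.run(tf.global_variables_initializer())
gvars = tf.trainable_variables()  # or tf.global_variables()
# this Saver stores only the student variables (names containing 'student')
save_var = [v for v in gvars if 'student' in v.name]
saver = tf.train.Saver(save_var)
# ... after training: saver.save(sess=session, save_path=save_path)

# --- inference: restore only the student network's variables ---
gvars = tf.trainable_variables()  # or tf.global_variables()
print(gvars)
# save_var = [v for v in gvars if 'teacher' not in v.name]  # alternative filter
save_var = [v for v in gvars if 'student' in v.name]
saver = tf.train.Saver(save_var)
saver.restore(sess=session, save_path=save_path)  # load the saved student variables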
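The save/restore rule above depends entirely on variable names. The pure-Python sketch below mimics the 'student' in v.name filter on a list of hypothetical variable names (not taken from the author's model) and also reports names that match neither tag. Those matter: a variable the student needs but whose name lacks the tag (a shared embedding, say) would not be stored by the student-only Saver.

```python
from typing import Iterable, List, Tuple


def split_variables(names: Iterable[str],
                    student_tag: str = "student",
                    teacher_tag: str = "teacher") -> Tuple[List[str], List[str], List[str]]:
    """Partition variable names the same way as the substring filter.

    Returns (student, teacher, unassigned) and raises if a name matches both
    tags, since such a variable would silently end up in the student Saver.
    """
    student, teacher, unassigned = [], [], []
    for name in names:
        is_student = student_tag in name
        is_teacher = teacher_tag in name
        if is_student and is_teacher:
            raise ValueError("ambiguous variable name: %s" % name)
        if is_student:
            student.append(name)
        elif is_teacher:
            teacher.append(name)
        else:
            unassigned.append(name)
    return student, teacher, unassigned


if __name__ == "__main__":
    # Hypothetical variable names, for illustration only.
    names = [
        "embedding/W:0",
        "rnn_teacher/kernel:0",
        "rnn_student/kernel:0",
        "score_teacher/dense/kernel:0",
        "score_student/dense/kernel:0",
    ]
    student, teacher, unassigned = split_variables(names)
    print("saved/restored:", student)
    print("not saved by the student Saver:", unassigned)
```

Running a check like this on the real tf.trainable_variables() names before building the Saver is a cheap way to catch a student variable that was never saved.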

2. TFLite quantization and compression

2.1 Converting the model to a .pb file

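Remember to give the input and output nodes explicit names in the model source code; here they are input_x and score_student/output_student.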

import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph):
    '''
    Freeze a checkpoint into a single .pb file that contains only the student output branch.
    :param input_checkpoint: checkpoint path prefix (the graph is read from input_checkpoint + '.meta')
    :param output_graph: path of the output .pb model
    :return: None
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # alternatively, find the latest checkpoint in a folder
    # input_checkpoint = checkpoint.model_checkpoint_path
    # output_node_names = "score_teacher/output_teacher"  # teacher output node
    output_node_names = "score_student/output_student"  # student output node; separate multiple names with commas
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # load the trained weights
        # turn variables into constants and keep only the nodes the outputs depend on
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalently sess.graph_def
            output_node_names=output_node_names.split(","),
            variable_names_whitelist=None,
            variable_names_blacklist=None)

        with tf.gfile.GFile(output_graph, "wb") as f:  # serialize the frozen GraphDef
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


input_checkpoint = './weaklyLearning_rnnw2v_transferMatrix_trace/compression_online/model/compress_model20200526_1/best_validation'
out_pb_path = './weaklyLearning_rnnw2v_transferMatrix_trace/compression_online/model/compress_model20200526_1/pbmodel.pb'
freeze_graph(input_checkpoint, out_pb_path)
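Besides turning variables into constants, freezing keeps only the nodes that the requested output depends on, which is why exporting score_student/output_student leaves the teacher branch behind. The toy sketch below illustrates that pruning step on a hand-written graph: node names and the dict representation (node name mapped to its input names, loosely like the input field of a GraphDef node) are hypothetical, control dependencies and output indices are ignored, and this is not TensorFlow's implementation.

```python
from collections import deque
from typing import Dict, List, Set


def reachable_nodes(graph: Dict[str, List[str]], outputs: List[str]) -> Set[str]:
    """Return every node the given outputs depend on (including the outputs).

    Walks breadth-first from the outputs backwards through each node's inputs;
    anything never visited is not needed to compute the outputs.
    """
    keep: Set[str] = set()
    queue = deque(outputs)
    while queue:
        node = queue.popleft()
        if node in keep:
            continue
        keep.add(node)
        queue.extend(graph.get(node, []))
    return keep


if __name__ == "__main__":
    # Hypothetical joint teacher-student graph.
    graph = {
        "input_x": [],
        "embedding": ["input_x"],
        "rnn_teacher": ["embedding"],
        "score_teacher/output_teacher": ["rnn_teacher"],
        "rnn_student": ["embedding"],
        "score_student/output_student": ["rnn_student"],
        "distill_loss": ["score_teacher/output_teacher", "score_student/output_student"],
    }
    kept = reachable_nodes(graph, ["score_student/output_student"])
    print("kept:", sorted(kept))
    print("pruned:", sorted(set(graph) - kept))
```

In this toy graph the teacher nodes and the distillation loss are pruned, while the shared input and embedding survive because the student output needs them.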

2.2 Quantization

import tensorflow as tf

# Post-training quantization; requires TensorFlow >= 1.14
# pip install --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow-gpu==1.14.1
in_tensors = ["input_x"]                         # input node name
out_tensors = ["score_student/output_student"]   # output node name
graph_def_file = './weaklyLearning_rnnw2v_transferMatrix_trace/compression_online/model/compress_model20200526_1/pbmodel.pb'
converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, in_tensors, out_tensors)
# converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
# converter.post_training_quantize = True  # older flag, deprecated in favour of optimizations
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# tf.lite.Optimize offers DEFAULT, OPTIMIZE_FOR_LATENCY and OPTIMIZE_FOR_SIZE
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_LATENCY]
tflite_model = converter.convert()
with open("./weaklyLearning_rnnw2v_transferMatrix_trace/compression_online/model/compress_model20200526_1/quantify_default_model.tflite", "wb") as f:
    f.write(tflite_model)
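To make the size/accuracy trade-off of quantization concrete, here is a minimal sketch of 8-bit affine quantization of a weight list: floats are mapped to integers in [-128, 127] through a scale and a zero point, then mapped back at inference time. It only illustrates the general idea behind post-training weight quantization; TFLite's actual scheme (per-tensor or per-channel, symmetric or asymmetric) depends on the version and the op, so those details are assumptions here.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float, int]:
    """Affine int8 quantization: q = round(w / scale) + zero_point, clamped to [-128, 127]."""
    qmin, qmax = -128, 127
    w_min = min(min(weights), 0.0)  # include 0 in the range so it is exactly representable
    w_max = max(max(weights), 0.0)
    scale = (w_max - w_min) / (qmax - qmin) or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = int(round(qmin - w_min / scale))
    q = [max(qmin, min(qmax, int(round(w / scale)) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize(q: List[int], scale: float, zero_point: int) -> List[float]:
    """Recover approximate float weights: w ≈ (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


if __name__ == "__main__":
    weights = [-0.9, -0.31, 0.0, 0.05, 0.42, 1.3]
    q, scale, zp = quantize_int8(weights)
    restored = dequantize(q, scale, zp)
    max_err = max(abs(a - b) for a, b in zip(weights, restored))
    print("int8 values:", q, "scale:", scale, "zero point:", zp)
    print("max abs error:", max_err, "(at most scale / 2 here)")
    # Each weight now takes 1 byte instead of 4 (float32): roughly 4x smaller storage.
```

The rounding error per weight is bounded by half the scale, which is why a wide weight range (large scale) costs more accuracy than a narrow one.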

3. Calling the model on a PC

Here is a working example:

import time

import numpy as np
import tensorflow as tf

if __name__ == '__main__':
    # 1. Load the model file
    model_path = './weaklyLearning_rnnw2v_transferMatrix_trace/compression_online/model/compress_model20200526_1/quantify_default_model.tflite'
    interpreter = tf.lite.Interpreter(model_path=model_path)
    # 2. Allocate the tensors
    interpreter.allocate_tensors()
    # 3. Get the input and output details
    input_details = interpreter.get_input_details()
    print(str(input_details))  # [{'name': 'input_x', 'index': 24, 'shape': array([ 1, 60], dtype=int32), 'dtype': <class 'numpy.int32'>, 'quantization': (0.0, 0)}]
    output_details = interpreter.get_output_details()
    print(str(output_details))
    # 4. Fill the input tensor: a dummy sequence of 60 token ids, shape [1, 60]
    text = list(range(60))
    input_data = np.expand_dims(text, axis=0)
    input_data = input_data.astype('int32')  # the dtype must also match the model input
    model_interpreter_start_time = time.time()
    for i in range(100):
        interpreter.set_tensor(input_details[0]['index'], input_data)
        # 5. Run inference
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
    model_interpreter_end_time = time.time()
    print(model_interpreter_end_time - model_interpreter_start_time)  # total seconds for 100 runs
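The loop above times 100 consecutive runs with time.time(), including the very first invocation, which may pay one-off costs. A slightly more careful, framework-agnostic sketch is to warm up first, time each call with time.perf_counter(), and report per-call statistics. The fake_invoke function is a hypothetical stand-in; with TFLite you would pass a small function that calls set_tensor, invoke and get_tensor as in the example.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], runs: int = 100, warmup: int = 10) -> Dict[str, float]:
    """Call fn a few times to warm up, then time each call; returns latencies in milliseconds."""
    for _ in range(warmup):
        fn()  # first calls may include one-off setup or cache effects
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    timings_ms.sort()
    return {
        "mean_ms": statistics.mean(timings_ms),
        "p50_ms": statistics.median(timings_ms),
        # nearest-rank style p99; with 100 runs this is the slowest call
        "p99_ms": timings_ms[min(len(timings_ms) - 1, int(0.99 * len(timings_ms)))],
    }


if __name__ == "__main__":
    def fake_invoke() -> int:
        # Hypothetical stand-in for interpreter.set_tensor / invoke / get_tensor.
        return sum(i * i for i in range(10_000))

    print(benchmark(fake_invoke, runs=100))
```

Reporting a median and a tail percentile alongside the mean makes it easier to tell whether occasional slow calls, rather than the typical call, dominate the per-record figure.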

Results and latency

1. PC (GPU machine)

0.08818244934082031 s for 100 inferences, i.e. 88.18 ms / 100 ≈ 0.88 ms per record.

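Online real-time inference: calling the quantized TFLite model on a PC makes real-time inference feasible at a daily volume on the order of a hundred million records (at 0.88 ms per record, a single process handles roughly 86,400 s / 0.00088 s ≈ 98 million records per day).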
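The daily-volume figure follows directly from the measured latency. A back-of-the-envelope helper, assuming records are processed strictly one after another per worker with no I/O or preprocessing overhead (real pipelines will differ):

```python
import math

SECONDS_PER_DAY = 24 * 60 * 60  # 86,400


def records_per_day(latency_ms: float, workers: int = 1) -> float:
    """Records per day if each worker finishes one record every latency_ms milliseconds."""
    return workers * SECONDS_PER_DAY / (latency_ms / 1000.0)


def workers_needed(target_per_day: float, latency_ms: float) -> int:
    """Smallest number of parallel workers that reaches target_per_day."""
    return math.ceil(target_per_day / records_per_day(latency_ms))


if __name__ == "__main__":
    pc = records_per_day(0.88)  # PC latency measured above
    print("PC, 1 worker: %.2e records/day" % pc)  # about 9.8e7
    print("PC workers for 1e8 records/day:", workers_needed(1e8, 0.88))
```

With the 0.88 ms figure, one worker falls just short of 1e8 records per day, so two parallel workers are enough to clear it under these assumptions.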

2. Mobile phone (CPU)

About 3 to 4 ms per record on average.

Compared with the large model, the accuracy drop can be kept within 3% (measured on a classification model).

