android keras训练模型 固化h5模型 转化为tflite格式

1.相互转化代码 h5转tflite h5转pb
对于模型导入android中,

  1. pb格式需在tensorflow1.13.1以下版本中转化为pb 很有限制
  2. 对于TensorFlow2.0的代码 最好转为tflite格式 导入android中 可运行 即高版本需要转化为tflite后进行

1.keras 训练的模型h5转tflite

import tensorflow as tf
import pathlib

export_dir = '/content/drive/My Drive/colab/test/middleReport/myuse'
# tf.saved_model.save(model, export_dir)
tf.saved_model.save(my_model, export_dir)
#转换模型。
# converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter = tf.lite.TFLiteConverter.from_keras_model(my_model)
tflite_model = converter.convert()
# tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file = pathlib.Path('modelTest.tflite')
tflite_model_file.write_bytes(tflite_model)
  1. Python 导入ttflite模型进行预测
import tensorflow as tf
model_path = "modelTest.tflite"
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()


# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))
# 填装数据
xxx = X_test[0][:150].reshape(-1,150,6)
# interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
interpreter.set_tensor(input_details[0]['index'], xxx)
        
# 注意注意,我要调用模型了
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print('result:{}'.format(output_data))
print(np.argmax(output_data))

3.keras 训练的模型h5转pb
第一种

import sys 
from keras.models import load_model 
import tensorflow as tf
import os
import os.path as osp
# from keras import backend as K
from tensorflow.compat.v1.keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        # freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables_initializer()).difference(keep_var_names or [])) 
        output_names = output_names or [] 
        output_names += [v.op.name for v in tf.global_variables()] 
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值