1. 模型格式相互转换代码:h5 转 tflite、h5 转 pb
将模型导入 Android 时:
- pb 格式需要在 TensorFlow 1.13.1 及以下版本中转换,限制较多;
- 对于 TensorFlow 2.0 的代码,最好先转换为 tflite 格式再导入 Android,可以正常运行。也就是说,高版本 TensorFlow 需要先转换为 tflite 再使用。
1. 将 Keras 训练的模型(h5)转换为 tflite(以下代码直接转换内存中的 Keras 模型 `my_model`;若从 h5 文件开始,需先用 `tf.keras.models.load_model` 加载)
import tensorflow as tf
import pathlib
# Output locations: the SavedModel directory (on Google Drive) and the .tflite file.
export_dir = '/content/drive/My Drive/colab/test/middleReport/myuse'
tflite_model_file = pathlib.Path('modelTest.tflite')

# Export the Keras model as a SavedModel.
# NOTE(review): `my_model` is defined outside this snippet. The converter below
# works from the Keras model object directly, not from this SavedModel export.
tf.saved_model.save(my_model, export_dir)

# Convert the Keras model to a TFLite flatbuffer and write it to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(my_model)
tflite_model = converter.convert()
tflite_model_file.write_bytes(tflite_model)
2. 在 Python 中导入 tflite 模型进行预测
import tensorflow as tf
# np.argmax below needs NumPy; the original snippet used `np` without importing it.
import numpy as np

model_path = "modelTest.tflite"
interpreter = tf.lite.Interpreter(model_path=model_path)
# Tensors must be allocated before inputs can be set or the model invoked.
interpreter.allocate_tensors()

# Get input and output tensor metadata (index, shape, dtype, ...).
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

# Fill the input: first 150 rows of the first test sample, reshaped to 3-D.
# NOTE(review): assumes X_test[0] holds at least 150 rows of 6 features and
# that the model input is (batch, 150, 6). Confirm against the training code.
xxx = X_test[0][:150].reshape(-1, 150, 6)
# set_tensor raises ValueError when the array dtype differs from the model's
# input dtype (e.g. float64 NumPy data vs a float32 model), so cast explicitly.
xxx = np.asarray(xxx, dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], xxx)

# Run inference and read the first output tensor.
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print('result:{}'.format(output_data))
# Index of the highest-scoring entry in the (flattened) output.
print(np.argmax(output_data))
3. 将 Keras 训练的模型(h5)转换为 pb
第一种方法
import sys
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
# from keras import backend as K
from tensorflow.compat.v1.keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
# freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables_initializer()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node: