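Based on: https://www.tensorflow.org/beta/guide/keras/functional
See also: https://www.tensorflow.org/lite/convert/index
TF 2.0 differs somewhat from TF 1.x; everything here refers to TF 2.0.
"Generating a TFLite model" simply means producing a model that can be deployed on mobile or IoT devices.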
Device deployment
The TensorFlow Lite FlatBuffer file is then deployed to a client device (e.g. mobile, embedded) and run locally using the TensorFlow Lite interpreter. This conversion process is shown in the diagram below:
Tool dependencies
sudo apt install python-pydot python-pydot-ng graphviz
sudo pip install tf-nightly-2.0-preview
pip install tensorflow==2.0.0-alpha0 (ran into problems with this version)
1] Create the model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Functional API: declare the input, then call each layer on the previous layer's output
inputs = keras.Input(shape=(784,), name='img')
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')
model.summary()
# plot_model needs pydot and graphviz (see the tool dependencies above)
keras.utils.plot_model(model, 'my_first_model.png')
keras.utils.plot_model(model, 'my_first_model_with_shape_info.png', show_shapes=True)
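model.summary() reports a parameter count for every Dense layer. As a cross-check, this plain-Python sketch (no TensorFlow needed; dense_params is a hypothetical helper, not a Keras API) reproduces those numbers from the rule that a Dense layer holds an n_inputs × units kernel plus one bias per unit:

```python
def dense_params(n_inputs: int, units: int, use_bias: bool = True) -> int:
    """Hypothetical helper: kernel weights (n_inputs x units) plus one bias per unit."""
    return n_inputs * units + (units if use_bias else 0)


# Layer widths of mnist_model above: 784 -> 64 -> 64 -> 10 (the Input layer has no weights)
sizes = [784, 64, 64, 10]
per_layer = [dense_params(a, b) for a, b in zip(sizes, sizes[1:])]
total = sum(per_layer)
print(per_layer)  # [50240, 4160, 650]
print(total)      # 55050
```

The total should match the "Total params" line printed by model.summary() for this model.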
2] Training, evaluation, and inference
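(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each 28x28 image to a 784-vector and scale pixels to [0, 1]
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
# Labels are integer class ids, hence the sparse variant of the loss
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])
# Hold out the last 20% of the training data for validation
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=5,
                    validation_split=0.2)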
test_scores = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])
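The heading mentions inference, but only training and evaluation are shown. To illustrate what inference computes, here is a NumPy sketch of the forward pass of the mnist_model architecture (Dense 64 relu, Dense 64 relu, Dense 10 softmax). The weights are random placeholders, an assumption for illustration only, so the predicted digits are meaningless; the point is the shape of the computation and that the predicted class is the argmax of the softmax output.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z):
    # Subtract the row max before exponentiating for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(x, params):
    """Mirror of mnist_model: Dense(64, relu) -> Dense(64, relu) -> Dense(10, softmax)."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h = relu(x @ w1 + b1)
    h = relu(h @ w2 + b2)
    return softmax(h @ w3 + b3)


# Random weights stand in for trained ones (placeholder, not a trained model)
rng = np.random.default_rng(0)
sizes = [784, 64, 64, 10]
params = [(rng.normal(0.0, 0.05, (a, b)).astype(np.float32), np.zeros(b, np.float32))
          for a, b in zip(sizes, sizes[1:])]

batch = rng.random((5, 784), dtype=np.float32)  # 5 fake flattened images in [0, 1)
probs = forward(batch, params)                  # shape (5, 10); each row sums to 1
predicted_digits = probs.argmax(axis=1)         # index of the most probable class
print(probs.shape, predicted_digits)
```

With the trained Keras model, the equivalent step is model.predict(x_test[:5]) followed by .argmax(axis=1).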
3] Saving and serialization
model.save('path_to_my_model.h5')
del model
# Recreate the exact same model purely from the file:
model = keras.models.load_model('path_to_my_model.h5')
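4] Convert the model. //Can step 3 be skipped, going straight to step 4? Yes: from_keras_model takes the in-memory Keras model, so saving it first is not required.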
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
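The converter output is the TensorFlow Lite FlatBuffer mentioned under Device deployment. A quick way to sanity-check the written file: FlatBuffers place a 4-byte file identifier right after the 4-byte root offset, and the TFLite schema declares its identifier as "TFL3". The sketch below assumes a regular (not size-prefixed) buffer; looks_like_tflite is a hypothetical helper name.

```python
import os


def looks_like_tflite(data: bytes) -> bool:
    """Return True if bytes 4..8 hold the TFLite FlatBuffer identifier "TFL3"."""
    return len(data) >= 8 and data[4:8] == b"TFL3"


path = "converted_model.tflite"  # the file written in step 4
if os.path.exists(path):
    with open(path, "rb") as f:
        print(path, "has TFL3 identifier:", looks_like_tflite(f.read()))
```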
5] TensorFlow Lite Python interpreter
//Using the interpreter from a model file
import numpy as np
import tensorflow as tf
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
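The interpreter example above feeds random noise. To run it on an actual digit, the input must match what input_details reports (for this model it should be [1, 784] float32, but input_details is the authority) and use the same /255 scaling as in step 2. A minimal NumPy sketch; to_interpreter_input and top_class are hypothetical helper names, and the bar image is a synthetic stand-in for a real MNIST digit.

```python
from typing import Tuple

import numpy as np


def to_interpreter_input(image_28x28: np.ndarray) -> np.ndarray:
    """Flatten one 28x28 uint8 image into a (1, 784) float32 batch scaled to [0, 1],
    matching the training preprocessing from step 2."""
    return image_28x28.reshape(1, 784).astype(np.float32) / 255


def top_class(output_data: np.ndarray) -> Tuple[int, float]:
    """Return (digit, probability) from a (1, 10) softmax output."""
    probs = output_data[0]
    digit = int(np.argmax(probs))
    return digit, float(probs[digit])


# Synthetic stand-in for a real digit: a vertical white bar on black
fake_image = np.zeros((28, 28), dtype=np.uint8)
fake_image[6:22, 13:15] = 255
x = to_interpreter_input(fake_image)
print(x.shape, x.dtype, x.max())  # (1, 784) float32 1.0
```

With a real image (e.g. x_test[0] from a fresh keras.datasets.mnist.load_data() call, a 28x28 uint8 array), the step 5 calls become interpreter.set_tensor(input_details[0]['index'], to_interpreter_input(img)), then interpreter.invoke(), and top_class(interpreter.get_tensor(output_details[0]['index'])) gives the predicted digit and its probability.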