与时俱进,TensorFlow 已经到 2.3 了,是时候学习 Keras 了。
习惯了图结构的写法,官网的 demo 看着不太顺眼,拿来改写一下。
import tensorflow as tf
import numpy as np
import cv2
# Train a small dense MNIST classifier with a hand-written training loop,
# save it as a Keras HDF5 file and export a quantized TFLite model.

# Load MNIST and scale pixels from [0, 255] to [0, 1]. Casting to float32 up
# front avoids feeding float64 arrays into the float32 Keras layers.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Functional-API model: flatten -> dense(128, relu) -> dropout -> dense(10).
x_inputs = tf.keras.layers.Input(shape=(28, 28))
net = tf.keras.layers.Flatten()(x_inputs)  # flatten to (None, 784)
net = tf.keras.layers.Dense(128, activation='relu')(net)  # fully connected + ReLU
net = tf.keras.layers.Dropout(0.2)(net)  # dropout, only active when training=True
net = tf.keras.layers.Dense(10, activation="softmax")(net)  # class probabilities
model = tf.keras.models.Model(inputs=x_inputs, outputs=net)

# Running metrics, reset at the start of every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
optimizer = tf.keras.optimizers.Adam()

BATCH_SIZE = 128
MAX_EPOCHS = 100  # safety cap so the loop cannot spin forever
TARGET_ACCURACY = 0.98  # stop once the epoch's training accuracy exceeds this

# Shuffle the whole training set (reshuffled on every pass) and split it into
# mini-batches instead of pushing all 60000 samples through in a single step.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .batch(BATCH_SIZE)
)


@tf.function
def train_step(images, labels):
    """Run one optimizer step on a mini-batch and update the running metrics."""
    with tf.GradientTape() as tape:
        # training=True is required for the Dropout layer to drop units.
        predictions = model(images, training=True)
        per_sample_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, predictions)
        # Reduce to a scalar so the gradient is a mean over the batch rather
        # than the implicit sum tape.gradient takes over a vector target.
        loss = tf.reduce_mean(per_sample_loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(per_sample_loss)
    train_accuracy(labels, predictions)


for epoch in range(MAX_EPOCHS):
    train_loss.reset_states()
    train_accuracy.reset_states()
    for images, labels in train_ds:
        train_step(images, labels)
    print(epoch + 1, float(train_loss.result()), float(train_accuracy.result()))
    if float(train_accuracy.result()) > TARGET_ACCURACY:
        break

model.save("./mnist.h5")

# Convert to TFLite with the default optimizations (post-training quantization).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Context manager guarantees the file is flushed and closed.
with open("mnist.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)
拿一张测试集里的图片试一下:
# Load the exported TFLite model and classify a single MNIST test image.
interpreter = tf.lite.Interpreter(model_path="./mnist.tflite")
interpreter.allocate_tensors()

# Input and output tensor metadata; printed for inspection.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)

# Only the test split is needed here, so the training split is discarded
# instead of being normalized into a large unused float64 array.
_, (x_test, y_test) = tf.keras.datasets.mnist.load_data()

SAMPLE_INDEX = 3
# Normalize just the one image; division yields float64, so cast to float32
# before handing it to the interpreter.
sample = (x_test[SAMPLE_INDEX] / 255.0).astype(np.float32)

# Add the leading batch dimension explicitly: (28, 28) -> (1, 28, 28).
input_batch = np.expand_dims(sample, axis=0)
interpreter.set_tensor(input_details[0]['index'], input_batch)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

# Index of the largest score is the predicted digit.
print(int(np.argmax(output_data[0])))

# Show the classified image until a key is pressed, then close the window.
cv2.imshow("1", sample)
cv2.waitKey(0)
cv2.destroyAllWindows()
效果应该不错。