识物应用:通过训练好的网络,实现对自己手写的数字的识别
代码为:
from PIL import Image
import numpy as np
import tensorflow as tf
model_save_path = './checkpoint/mnist.ckpt'
#复现模型(构建一个与之前训练过的模型相同的网络)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
#读取模型参数
model.load_weights(model_save_path)
preNum = int(input("input the number of test pictures:"))
for i in range(preNum):
image_path = input("the path of test picture:")
img = Image.open(image_path)
img = img.resize((28, 28), Image.ANTIALIAS)
img_arr = np.array(img.convert('L'))
#因为输入的是白底黑字,而我们训练的是黑底白字,所以对像素进行处理
img_arr = 255 - img_arr
img_arr = img_arr / 255.0
print("img_arr:",img_arr.shape)
#将img_arr前填加一个维度1,因为模型的是默认有batch的,而我们这里没有分batch所以加一个一维
x_predict = img_arr[tf.newaxis, ...]
print("x_predict:",x_predict.shape)
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
print('\n')
tf.print(pred)