import tensorflow as tf
# from my_tf_cn_1 import my_model
import numpy as np
import matplotlib.pyplot as plt
# Rebuild the CNN, restore its most recent checkpoint, and sanity-check the
# predictions on the first few MNIST training digits.

CKPT_DIR = 'cnn/ckpt'   # directory searched for the most recent checkpoint
NUM_SAMPLES = 50        # number of leading training images to predict on

# Build the model. The layer stack has to match the one the checkpoint was
# saved from, otherwise the weights cannot be restored.
# NOTE(review): presumably the training script is my_tf_cn_1 -- confirm.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPool2D((2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.summary()

# Load weights. latest_checkpoint() returns None when no checkpoint is found,
# so fail early with a clear message instead of an obscure load_weights error.
latest = tf.train.latest_checkpoint(CKPT_DIR)
if latest is None:
    raise FileNotFoundError(
        f"No checkpoint found in {CKPT_DIR!r}; train and save the model first."
    )
model.load_weights(latest)

# Load data and scale the pixel values by 1/255.
(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()
xtrain, xtest = xtrain / 255.0, xtest / 255.0

# Prediction inputs: add the trailing channel axis expected by the Conv2D layer.
x = xtrain[:NUM_SAMPLES].reshape((-1, 28, 28, 1))
y = model.predict(x)

# Predicted class per sample (index of the largest softmax output), followed
# by the ground-truth labels for a side-by-side comparison.
print(np.argmax(y, axis=1))
print(ytrain[:NUM_SAMPLES])

# Display the first training image.
plt.imshow(xtrain[0], cmap='gray')
plt.show()
# SDUWH 2019-2020 winter-break Python training -- my_tf_cn_2
# (Blog page footer: latest recommended article published 2024-05-15 22:12:36)