1. The dataset used in this code (Fashion-MNIST) can be loaded as follows:
import tensorflow as tf
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
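`load_data()` returns NumPy `uint8` arrays: 60,000 training and 10,000 test images of 28×28 pixels, with integer labels 0–9. Keras returns only the numbers, so the class-name list below is taken from the Fashion-MNIST dataset description. The helpers `label_histogram` and `scale_pixels` are hypothetical, written in plain Python so the sketch runs without TensorFlow.

```python
from collections import Counter

# Fashion-MNIST class names, indexed by label value 0-9.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_histogram(labels):
    """Map each class name to how many times its label occurs (hypothetical helper)."""
    counts = Counter(int(label) for label in labels)  # int() also accepts NumPy uint8 values
    return {CLASS_NAMES[i]: counts.get(i, 0) for i in range(len(CLASS_NAMES))}


def scale_pixels(image):
    """Convert 0-255 integer pixels (a list of rows) to floats in [0, 1] (hypothetical helper)."""
    return [[pixel / 255.0 for pixel in row] for row in image]


if __name__ == "__main__":
    print(label_histogram([9, 0, 0, 3]))
    print(scale_pixels([[0, 255], [51, 102]]))
```

Because the training split is balanced, `label_histogram(train_labels)` should report 6,000 images per class. `scale_pixels` is the element-by-element version of the common `train_images / 255.0` preprocessing step. If you add that scaling to the training script, apply the same scaling to the test images before prediction.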
2. Runtime environment
Tensorflow-gpu==2.4.0
Python==3.7
3. Training code:
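import os
# Set these before importing TensorFlow so they are in place when it initializes:
# TF_XLA_FLAGS enables the XLA devices that TF 2.4 no longer registers by default;
# TF_FORCE_GPU_ALLOW_GROWTH makes TF allocate GPU memory on demand instead of all at once.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

my_input = tf.keras.Input(shape=(28, 28))
x = tf.keras.layers.Flatten()(my_input)                # 28x28 image -> 784-element vector
x = tf.keras.layers.Dense(32, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)                    # zeroes 50% of activations, during training only
x = tf.keras.layers.Dense(64, activation='relu')(x)    # without an activation this layer would be purely linear
my_output = tf.keras.layers.Dense(10, activation='softmax')(x)  # one probability per class
model = tf.keras.Model(inputs=my_input, outputs=my_output)
model.summary()

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # labels are integers 0-9, not one-hot vectors
    metrics=['accuracy']
)
# The test set doubles as validation data here, so val_accuracy is not a truly held-out estimate.
history = model.fit(train_images, train_labels, epochs=100, validation_data=(test_images, test_labels))

os.makedirs('model', exist_ok=True)  # saving fails if the 'model' directory does not exist
model.save(r'model/my_model.h5')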
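`model.summary()` prints a parameter count for each layer. The plain-Python sketch below recomputes those counts from the layer sizes used above. It assumes each Dense layer has a weight matrix plus a bias vector (the Keras default, `use_bias=True`) and that Flatten and Dropout carry no weights. The function names are hypothetical.

```python
# Sketch: reproduce the "Param #" column of model.summary() for
# Flatten -> Dense(32) -> Dropout -> Dense(64) -> Dense(10).

def dense_params(n_in, n_out):
    """A Dense layer holds an n_in x n_out weight matrix plus n_out biases."""
    return n_in * n_out + n_out


def count_params(input_shape=(28, 28), units=(32, 64, 10)):
    """Return per-Dense-layer parameter counts and their total.

    Flatten and Dropout have no trainable weights, so only the Dense
    layers contribute.
    """
    n_in = 1
    for dim in input_shape:      # Flatten: 28 * 28 = 784 features
        n_in *= dim
    per_layer = []
    for n_out in units:
        per_layer.append(dense_params(n_in, n_out))
        n_in = n_out             # this layer's width is the next layer's input size
    return per_layer, sum(per_layer)


if __name__ == "__main__":
    layers, total = count_params()
    print(layers, total)  # [25120, 2112, 650] 27882
```

The total should match the "Total params" line of the summary. An activation function has no weights, so changing a layer's activation does not change these counts.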
4. Prediction code:
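import os
# Same environment settings as the training script, set before TensorFlow is imported.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
model = tf.keras.models.load_model(r'model/my_model.h5')
pre = model.predict(test_images)       # shape (10000, 10): one probability per class for each image
my_predict = np.argmax(pre[0])         # index of the most probable class for the first test image
print('predicted:', my_predict, 'true label:', test_labels[0])

plt.imshow(test_images[0], cmap='gray')
plt.show()                             # blocks until the window is closed, so print first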
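`np.argmax(pre[0])` keeps only the single most probable class for one image. The plain-Python sketch below shows two common follow-ups: listing the few most probable labels for one image, and scoring many predictions against the true labels. It assumes each row of `pre` is a softmax probability vector, as produced by the `Dense(10, activation='softmax')` output layer. `top_k` and `accuracy` are hypothetical helpers, not Keras APIs.

```python
def top_k(probabilities, k=3):
    """Return the k (label, probability) pairs with the highest probability.

    `probabilities` is one row of model.predict(...). sorted() is stable even
    with reverse=True, so ties keep the lower label first and
    top_k(p, 1)[0][0] agrees with np.argmax(p).
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def accuracy(predicted_labels, true_labels):
    """Fraction of predicted labels that equal the ground-truth labels."""
    if len(predicted_labels) != len(true_labels):
        raise ValueError("prediction and label lists differ in length")
    matches = sum(1 for p, t in zip(predicted_labels, true_labels) if p == t)
    return matches / len(true_labels)


if __name__ == "__main__":
    probs = [0.05, 0.70, 0.25]
    print(top_k(probs, k=2))  # [(1, 0.7), (2, 0.25)]
    print(accuracy([1, 2, 3], [1, 2, 0]))
```

With the real arrays, you could call `top_k(pre[0].tolist())` for the first image and `accuracy(np.argmax(pre, axis=1).tolist(), test_labels.tolist())` for the whole test set. Inside Keras, `model.evaluate(test_images, test_labels)` reports the compiled loss and accuracy directly.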