1. The dataset used in the code can be loaded with the following code:
import tensorflow as tf
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
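The labels returned by load_data() are integers from 0 to 9. The short sketch below maps them to the ten Fashion-MNIST class names in the order given in the Keras dataset documentation. CLASS_NAMES and label_to_name are hypothetical helpers introduced here for illustration; they are not part of the original code.

```python
# Hypothetical helper: Fashion-MNIST class names indexed by label (0-9),
# in the order listed in the Keras dataset documentation.
CLASS_NAMES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]


def label_to_name(label):
    """Return the class name for an integer label such as train_label[0]."""
    index = int(label)  # accepts Python ints and NumPy integer scalars
    if not 0 <= index < len(CLASS_NAMES):
        raise ValueError('Fashion-MNIST labels are in the range 0-9, got %d' % index)
    return CLASS_NAMES[index]


print(label_to_name(9))  # Ankle boot
```

Once the dataset is loaded, label_to_name(train_label[0]) gives a readable name for the first training image.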
2. Runtime environment
tensorflow-gpu==2.4.0
Python==3.7
3. Training code:
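import os

# Set these before TensorFlow initializes any GPU device.
# Enables the XLA devices that TensorFlow 2.4 no longer creates by default.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
# Allocate GPU memory on demand instead of reserving all of it up front.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import tensorflow as tf

# Fashion-MNIST: 60,000 training and 10,000 test 28x28 grayscale images, labels 0-9
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
train_image = train_image / 255.0
test_image = test_image / 255.0

# A simple fully connected classifier
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))   # 28x28 image -> 784-element vector
model.add(tf.keras.layers.Dense(128, activation='relu'))    # hidden layer
model.add(tf.keras.layers.Dense(10, activation='softmax'))  # one probability per class

# Labels are integer class indices (not one-hot), so use the sparse cross-entropy loss
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc']
)
model.fit(train_image, train_label, epochs=50)

# Saving to HDF5 fails if the target directory does not exist, so create it first
os.makedirs('model_data', exist_ok=True)
model.save(r'model_data/model.h5')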
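Because train_label holds integer class indices rather than one-hot vectors, the model is compiled with sparse_categorical_crossentropy. For each example this loss is the negative log of the probability the softmax assigns to the true class, averaged over the batch; it is equivalent to categorical cross-entropy on one-hot labels. The pure-Python sketch below only illustrates that definition and is not Keras's implementation. The function name and the eps clip that avoids log(0) are assumptions made for this example.

```python
import math


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean of -log(p[true class]) over a batch (illustrative sketch).

    labels: integer class indices, one per example.
    probs:  per-example lists of class probabilities (e.g. softmax outputs).
    eps:    assumed clip value that keeps log() away from 0.
    """
    losses = []
    for label, p in zip(labels, probs):
        p_true = min(max(p[label], eps), 1.0 - eps)  # keep the value inside (0, 1)
        losses.append(-math.log(p_true))
    return sum(losses) / len(losses)


# Two examples, 3 classes: the true classes get probability 0.7 and 0.5.
loss = sparse_categorical_crossentropy([1, 0], [[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]])
print(round(loss, 4))  # 0.5249
```

A confident, correct prediction drives the loss toward 0, while assigning low probability to the true class makes it grow quickly, which is what pushes the softmax outputs toward the right labels during model.fit.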
4. Evaluation code:
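import os

# Same environment settings as in training, set before TensorFlow initializes the GPU
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import tensorflow as tf

# Only the test split is needed for evaluation
(_, _), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
# Apply the same [0, 1] scaling that was used during training
test_image = test_image / 255.0

# Load the model saved by the training script
model = tf.keras.models.load_model(r'model_data/model.h5')
# evaluate() returns the compiled loss followed by the 'acc' metric
test_loss, test_acc = model.evaluate(test_image, test_label)
print('test loss:', test_loss, 'test accuracy:', test_acc)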
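model.evaluate reports the loss and the 'acc' metric on the test set. Keras picks the accuracy variant from the label and output shapes; with integer labels and a 10-way softmax output this is sparse categorical accuracy, where a prediction counts as correct when the index of its largest probability equals the label. The pure-Python sketch below shows that calculation; sparse_categorical_accuracy here is an illustrative function, not the Keras metric object.

```python
def sparse_categorical_accuracy(labels, probs):
    """Fraction of examples whose highest-probability class equals the label."""
    correct = 0
    for label, p in zip(labels, probs):
        predicted = max(range(len(p)), key=lambda i: p[i])  # argmax over classes
        if predicted == label:
            correct += 1
    return correct / len(labels)


# Example 1 is predicted as class 1 (correct); example 2 as class 2 (label is 0).
acc = sparse_categorical_accuracy([1, 0], [[0.1, 0.7, 0.2], [0.3, 0.2, 0.5]])
print(acc)  # 0.5
```

Applying the same calculation to model.predict(test_image) and test_label gives the quantity that evaluate reports as acc.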