Tensorflow学习,Fashion数据集,包含7万张图片用于训练和测试,其中6万张用于训练,1万张用于测试。
##注意:
代码在运行时加载数据集可能会从外网下载,可能会导致下载失败中断运行。
给出两个方法:
1、直接从官网下载,附上教程链接:https://www.tensorflow.org/tutorials/keras/classification
2、条件允许的可以直接开启科学上网,就可以下载成功啦
最后附上完整代码:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
class FashionModel(Model):
def __init__(self):
super(FashionModel, self).__init__()
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x, training=None, mask=None):
x = self.flatten(x)
x = self.d1(x)
y = self.d2(x)
return y
model = FashionModel()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']
)
model.fit(x_train, y_train, batch_size=32, epochs=20, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
运行成功:
Epoch 18/20
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4907 - sparse_categorical_accuracy: 0.8283 - val_loss: 0.5619 - val_sparse_categorical_accuracy: 0.8230
Epoch 19/20
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4904 - sparse_categorical_accuracy: 0.8289 - val_loss: 0.5419 - val_sparse_categorical_accuracy: 0.8139
Epoch 20/20
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4721 - sparse_categorical_accuracy: 0.8324 - val_loss: 0.6536 - val_sparse_categorical_accuracy: 0.7796
Model: "fashion_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
_________________________________________________________________
dense (Dense) multiple 100480
_________________________________________________________________
dense_1 (Dense) multiple 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
进程已结束,退出代码0