一. 数据集介绍
- Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
- Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)
二. 训练模型
- 实质也是一个线性回归模型,softmax得到的是各个类别的概率,这些概率总和为1,
- 训练过程跟之前一样,定义模型、定义损失函数、分类精度只是方便我们衡量检测效果,不参与训练过程的参数优化,
- 训练代码:
import tensorflow as tf
from IPython import display
from d2l import tensorflow as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10
W = tf.Variable(tf.random.normal(shape=(num_inputs, num_outputs), mean=0, stddev=0.01))
b = tf.Variable(tf.zeros(num_outputs))
def softmax(X):
X_exp = tf.exp(X)
partition = tf.reduce_sum(X_exp, 1, keepdims=True)
return X_exp / partition
def net(X):
return softmax(tf.matmul(tf.reshape(X, (-1, W.shape[0])), W) + b)
def cross_entropy(y_hat,