如何导入mnist数据集并且预处理
mnist是我们在学习机器学习或者深度学习时常用的数据集,以下是使tensorflow导入mnist的一种方法。
在深度学习训练中,为了提高执行效率,常进行矩阵操作,因此后面还对数据集进行的简单的预处理,方便使用。
以下函数可直接复制到你的代码中,调用即可。
def load_mnist():
import tensorflow.keras as keras
(x,y),(x_test,y_test) = keras.datasets.mnist.load_data()
x_train_flatten = x.reshape(x.shape[0], -1).T
y_train = y.reshape(y.shape[0], -1).T
x_test_flatten = x_test.reshape(x_test.shape[0], -1).T
y_test_flatten = y_test.reshape(y_test.shape[0], -1).T
return x_train_flatten,y_train, x_test_flatten, y_test_flatten