Tensorflow学习-MNIST数据集CNN
CNN
①数据集导入,keras自带的下载或者从某盘提取点击获取数据集,提取码:45yf
#加载MNIST数据集
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data('E:/TensorFlow_mnist/MNIST_data/mnist.npz')
print(x_train.shape,type(x_train)) #60000张28*28的图片
print(y_train.shape,type(y_train)) #60000个标签
②图像和数据类型的转化
这里使用通道的方式进行数据类型转化,channels_first(batch,channels,height,width)
#数据处理:规范化
from keras import backend as K
img_rows,img_cols=28,28;
#channels_first(batch,channels,height,width)
if K.image_data_format()=='channels_first':
x_train = x_train.reshape(x_train.shape[0],1,img_rows,img_cols)
x_test = x_test.reshape(x_test.shap[0],1,img_rows,img_cols)
input_shape = (1,img_rows,img_cols)
else:
x_train =x_train.reshape(x_train.shape[0],img_rows,img_cols,1)
x_test=x_test.reshape(x_test.shape[0],img_rows,img_cols,1)
input_shape=(img_rows,img_cols,1)
print(x_train