代码如下所示:
# -*- coding: utf-8 -*-
# @Time : 2018/4/4 13:22
# @Author : mgliu
# @FileName: mnist.py
# @Software: PyCharm Community Edition
# -*- coding: utf-8 -*-
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten
from keras.layers.convolutional import Conv2D,MaxPooling2D
# Load the MNIST training and test splits as (images, labels) pairs.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
print(X_train.shape)
print(y_train[0])

# Add a trailing single-channel axis so every image is 28x28x1, and cast to
# float32. Both splits use the same dtype; the test split was previously cast
# with 'float' (float64), which was inconsistent with the training split.
# The cast must happen before the in-place division below, which would fail
# on an integer array.
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

# Divide by 255 in place (MNIST pixels are 8-bit, so this maps them to [0, 1]).
X_train /= 255
X_test /= 255
# Re-encode a label as a one-hot vector.
def tran_y(y, num_classes=10):
    """Return the one-hot encoding of the class index ``y``.

    Args:
        y: Integer class index; position ``y`` of the result is set to 1.
        num_classes: Length of the returned vector. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        A 1-D NumPy array of length ``num_classes`` (float64, the default
        dtype of ``np.zeros``) that is 0 everywhere except for a 1 at
        index ``y``.

    Raises:
        IndexError: If ``y`` is out of bounds for an array of length
            ``num_classes``.
    """
    y_ohe = np.zeros(num_classes)
    y_ohe[y] = 1
    return y_ohe
# One-hot encode every training label; rows follow the order of y_train.
y_train_ohe = np.array([tran_y(label) for label in y_train])
y_test_ohe=np.array([tran_y(y_test[i])for i in