一、数据集介绍:MNIST数据集——来自美国国家标准与技术研究所。训练集(Training Set)由250个不同的人手写的数字构成,其中50%是高中学生,50%来自人口普查局(The Census Bureau)的工作人员。测试集(Testing Set)也是同样比例的手写数字数据。该案例的目的是区分0~9这10个数字。
二、建模步骤:
2.1导入MNIST数据
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("data/MNIST/",one_hot=False)
2.2准备训练数据与测试数据
X0 = data.train.images
Y0 = data.train.labels
X1 = data.validation.images
Y1 = data.validation.labels
print(X0.shape)
2.3手写数字展示
from matplotlib import pyplot as plt
plt.figure()
fig,ax = plt.subplots(2,5)
ax=ax.flatten()
for i in range(10):
Im=X0[Y0==i][0].reshape(28,28)
ax[i].imshow(Im)
plt.show()
2.4产生one-hot型因变量
from keras.utils import to_categorical
YY0=to_categorical(Y0)
YY1=to_categorical(Y1)
2.5逻辑回归模型的构建
from keras.layers import Activation, Dense, Flatten, Input
from keras import Model
input_shape=(784,)
input_layer=Input(input_shape)
x=input_layer
x=Dense(10)(x)
x=Activation('softmax')(x)
output_layer=x
model=Model(input_layer,output_layer)
2.6模型编译
from keras.optimizers import Adam
model.compile(optimizer = Adam(0.01),
loss = 'categorical_crossentropy',
metrics = ['accuracy'])
2.7模型拟合
model.fit(X0,YY0,
validation_data=(X1,YY1),
batch_size=1000,
epochs=10)
2.8参数估计结果可视化
plt.figure()
fig,ax = plt.subplots(2,5)
ax=ax.flatten()
weights = model.layers[1].get_weights()[0]
for i in range(10):
Im=weights[:,i].reshape((28,28))
ax[i].imshow(Im,cmap='seismic')
ax[i].set_title("{}".format(i))
ax[i].set_xticks([])
ax[i].set_yticks([])
plt.show()