入门级别案例mnist手写数字识别
一、用keras实现手写识别
其中也没有什么需要着重注意的点,把几个易错点罗列一下:
- (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
左边要用两个元组接 - from keras.utils import to_categorical
用来处理样本标签,将其转换为one-hot编码
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels) - 源码附上
from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical
# 加载并观察
def load_data():
"""
导入数据并观察
:return:
"""
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print('训练集样本形状:\n', train_images.shape) # (60000, 28, 28)
print('训练集标签形状:\n', train_labels, train_labels.shape) # [5 0 4 ... 5 6 8] (60000,)
print('---------------------------------')
print('测试集样本形状:\n', test_images.shape) # (10000, 28, 28)
print('测试集标签形状:\n', test_labels, test_labels.shape) # [7 2 1 ... 4 5 6] (10000,)
return train_images, train_labels, test_images, test_labels
def model():
"""
建立神经网络模型
:return:
"""
# 建立顺序模型
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax')) # softmax返回10个概率值,总和为1的数组
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
return network
def data_processing(train_images, train_labels, test_images, test_labels):
"""
准备图像数据和标签
:return:
"""
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32') / 255
# 准备标签,转换成one-hot编码
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
print('转换效果为:\n', train_labels)
return train_images, train_labels, test_images, test_labels
if __name__ == "__main__":
train_images, train_labels, test_images, test_labels = load_data()
train_images, train_labels, test_images, test_labels = data_processing(train_images, train_labels, test_images, test_labels)
network = model()
# 训练模型
network.fit(train_images, train_labels, epochs=5, batch_size=128) # epochs大约是线程的意思
# 显示精度
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print('测试损失值:\n', test_loss)
print('测试准确率:\n', test_accuracy)
疑问
- 自动训练了一万步,并没有自行指定
以后解决
一、用keras实现手写识别