这里采用CNN模型(卷积神经网络)来进行MNIST数据集的分类识别
1 导入模块
首先,导入需要的模块
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import models, layers
import matplotlib.pyplot as plt
2 载入MNIST数据集
调用keras集成的mnist的load_data函数载入数据集
# load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# train_images: 60000*28*28, train_labels: 60000*1
# test_images: 10000*28*28, test_labels: 10000*1
# pre-process data, change data shape & type
train_input = train_images.reshape(60000,28,28,1)
train_input = train_input.astype('float32')/255
test_input = test_images.reshape(10000,28,28,1)
test_input = test_input.astype('float32')/255
train_output = keras.utils.to_categorical(train_labels)
test_output = keras.utils.to_categorical(test_labels)
3 构建模型
构建一个卷积神经网络,定义构建函数如下