基于 TensorFlow.keras 的 CNN 架构实现 MNIST 手写数字分类
1 导入库
import tensorflow as tf
import numpy as np
import matplotlib. pyplot as plt
import pandas as pd
from tensorflow import keras
% matplotlib inline
2 读取MNIST数据集
# 2. Fetch MNIST through the Keras datasets API (per the Keras docs it is
# downloaded and cached locally on first use). load_data() yields a
# (train, test) pair, each of which is an (images, labels) tuple.
mnist = keras.datasets.mnist
(
    (x_train, y_train),
    (x_test, y_test),
) = mnist.load_data()
3 将图像重塑为四维张量（样本数, 28, 28, 1）并归一化到 [0, 1]
# 3. Append a trailing channel axis so every image becomes (28, 28, 1), the
# layout Conv2D expects, then scale pixel values into [0, 1].
# keras.datasets.mnist.load_data() returns uint8 arrays (per its docs), so
# dividing them directly by 255.0 would yield float64. Casting to float32
# first halves memory, and float32 is Keras's default floatx anyway.
# (NumPy keeps float32 when dividing by a Python float scalar.)
x_train = x_train.reshape(x_train.shape + (1,)).astype("float32") / 255.0
x_test = x_test.reshape(x_test.shape + (1,)).astype("float32") / 255.0
4 将数据标签转换为 one-hot 编码
# 4. One-hot encode the integer digit labels into 10-column vectors, the
# target format required by the categorical_crossentropy loss used below.
label_train, label_test = (
    keras.utils.to_categorical(labels, 10) for labels in (y_train, y_test)
)
5 模型构建
# 5. Build the CNN: three convolution stages (64 -> 128 -> 256 filters), each
# closed by 2x2 max pooling, followed by a dropout-regularised dense head.
# The input is declared with an explicit `tf.keras.Input` instead of the
# `input_shape=` keyword on the first Conv2D. Keras 3 deprecates that keyword
# and warns about it. The layer stack itself is unchanged.
model = tf.keras.models.Sequential([
    # (28, 28, 1) matches the reshaped image arrays fed to fit/predict.
    tf.keras.Input(shape=(28, 28, 1)),
    # Stage 1: one wide 7x7 conv ("same" padding keeps 28x28); pool -> 14x14.
    tf.keras.layers.Conv2D(64, 7, activation="relu", padding="same"),
    tf.keras.layers.MaxPooling2D(2),
    # Stage 2: two 3x3 convs at 14x14; pool -> 7x7.
    tf.keras.layers.Conv2D(128, 3, activation="relu", padding="same"),
    tf.keras.layers.Conv2D(128, 3, activation="relu", padding="same"),
    tf.keras.layers.MaxPooling2D(2),
    # Stage 3: two 3x3 convs at 7x7; pool -> 3x3 (pooling uses "valid"
    # padding by default, so the odd 7 is floored to 3).
    tf.keras.layers.Conv2D(256, 3, activation="relu", padding="same"),
    tf.keras.layers.Conv2D(256, 3, activation="relu", padding="same"),
    tf.keras.layers.MaxPooling2D(2),
    # Classifier head: 3 * 3 * 256 = 2304 flattened features.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    # 10-way softmax: one probability per digit class.
    tf.keras.layers.Dense(10, activation="softmax"),
])
6 模型显示
# 6. Print a layer-by-layer table of output shapes and parameter counts.
model. summary( )
7 使用SGD编译模型
# 7. Configure training: the "sgd" string selects SGD with Keras's default
# settings. The categorical cross-entropy loss matches the one-hot labels,
# and accuracy is tracked under the "acc" metric name.
model.compile(
    loss="categorical_crossentropy",
    optimizer="sgd",
    metrics=["acc"],
)
8 训练 20 个周期（epoch），并留出 20% 的训练数据作为验证集（留出法验证，而非交叉验证）
# 8. Train for 20 epochs. validation_split=0.2 reserves the last 20% of the
# training samples (Keras takes them before shuffling) as a hold-out
# validation set. The returned History object records per-epoch loss and
# metrics, including the val_* entries.
history = model.fit(
    x=x_train,
    y=label_train,
    validation_split=0.2,
    epochs=20,
)
9 预测
# 9. Predict on the test set: pick the highest-probability class per sample,
# then report the fraction of predictions equal to the integer test labels.
y_pred = np.argmax(model.predict(x_test), axis=1)
# np.mean over the boolean match array is vectorised. The builtin `sum`
# would iterate the NumPy array one element at a time in Python.
print(f"prediction accuracy: {np.mean(y_pred == y_test)}")
10 绘制训练集与验证集的损失曲线
# 10. Plot training and validation loss per epoch. `history` is the History
# object returned by model.fit. Its `.history` dict maps each logged metric
# name to a per-epoch list. 'val_loss' exists because fit was given a
# validation_split.
plt.plot(history.history['loss'], label='training set loss')
plt.plot(history.history['val_loss'], label='validation set loss')
plt.ylabel('categorical cross-entropy')
plt.xlabel('epoch')
plt.legend()
11 模型训练精度及结果
11.1 精度
11.2 损失曲线