# 导入相应的库文件 (Import the required libraries)
from tensorflow. keras. datasets import mnist, fashion_mnist
from tensorflow. keras. layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow. keras. models import Sequential, load_model
from tensorflow. keras import optimizers, losses, callbacks
import tensorflow as tf
# 下载数据集并划分训练集与测试集，像素值除以 255 归一化 (Load Fashion-MNIST as train/test splits and scale pixel values by 1/255)
# load_data() already returns the data split into (train, test) pairs.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Divide the raw pixel values by 255.0 to get floating-point inputs in a
# small range for the network.
x_train = x_train / 255.0
x_test = x_test / 255.0
# 搭建模型 (Build the CNN model)
# Small CNN classifier for 28x28 single-channel images.
# Layer stack:
#   - Conv(32) -> MaxPool -> Conv(64) -> MaxPool -> Conv(64), all 3x3 ReLU
#     convolutions with 2x2 pooling;
#   - then Flatten -> Dense(64, ReLU) -> Dense(10, softmax) over the classes.
model = tf.keras.Sequential()
model.add(
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
)
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Print the layer-by-layer architecture and parameter counts.
model.summary()
# 定义损失函数及优化器 (Define the loss function and optimizer)
# Configure training with:
#   - Adam using its default hyper-parameters;
#   - non-sparse categorical cross-entropy, because the labels are one-hot
#     encoded before fitting;
#   - accuracy reported under the short name 'acc'.
model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['acc'],
)
# 将训练集、测试集的 x 变形为 (-1, 28, 28, 1)，补上通道维度 (Reshape the training/test images to add the channel axis)
# Append the single channel axis that the first Conv2D layer expects
# (input_shape=(28, 28, 1)). The -1 lets reshape keep the sample count.
x_train, x_test = (images.reshape(-1, 28, 28, 1) for images in (x_train, x_test))
# 将 y 进行 one-hot 编码 (One-hot encode the labels)
# Turn integer class labels into one-hot rows, matching the
# categorical_crossentropy loss chosen in compile().
y_train_one, y_test_one = (
    tf.keras.utils.to_categorical(labels) for labels in (y_train, y_test)
)
# 训练模型 (Train the model)
# Train for 20 epochs. The returned History object's .history dict holds
# the per-epoch 'loss' and 'acc' values.
hist = model.fit(x=x_train, y=y_train_one, epochs=20)
# 训练结果 (Training results):
# 评估模型 (Evaluate the model)
import matplotlib.pyplot as plt

# Per-epoch training curves recorded by fit().
loss = hist.history['loss']
acc = hist.history['acc']
T = len(loss)
epoch_axis = range(T)

# Use a CJK-capable font because the legend labels below are Chinese.
plt.rcParams['font.family'] = 'SimHei'
# Draw minus signs as an ASCII hyphen. Presumably this is because SimHei
# lacks the Unicode minus glyph; confirm.
plt.rcParams['axes.unicode_minus'] = False

# Plot loss first, then accuracy, on shared axes.
for curve, caption in ((loss, '损失'), (acc, '准确率')):
    plt.plot(epoch_axis, curve, label=caption)
plt.legend()
plt.show()

# Score the trained model on the held-out, one-hot-labelled test set.
model.evaluate(x_test, y_test_one)
# 测试集评估结果 [loss, acc] (Test-set evaluation output): [0.4195633828639984, 0.9110000133514404]
# 预测数据并可视化输出 (Predict on the test set and display the results)
# Predict class probabilities for every test image. Then show the first
# n_rows * n_cols test images in a grid, each titled with its predicted
# class index.
pred = model.predict(x_test)
# Keras' predict() returns a NumPy array, so argmax over the last (class)
# axis yields every predicted label at once. This replaces the original
# np.argmax call, which failed because `np` is never imported in this file.
pred_labels = pred.argmax(axis=1)

n_rows, n_cols = 10, 10
data = x_test[: n_rows * n_cols]
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= len(data):
        # Fewer images than grid cells: leave the remaining cells blank
        # instead of raising IndexError.
        continue
    # x_test was reshaped to carry a trailing channel axis of size 1. Squeeze
    # it so imshow receives a plain 2-D grayscale array.
    ax.imshow(data[i].squeeze(), cmap='gray')
    ax.set_title(f'预测: {pred_labels[i]} ')
plt.tight_layout()
plt.show()