cifar10数据集介绍
cifar10是一个图像识别的数据集,一共有6000张32x32像素的彩色图片(训练集:测试集=5:1);这些图片一共分为10类(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车)。
代码及解析
导入需要的库
import keras
import numpy as np
from keras. models import Sequential
from keras. layers import Conv2D, MaxPooling2D, Dense, Dropout, Flatten
from keras. datasets import cifar10
可选(控制GPU资源,动态申请显存)
import tensorflow as tf
config = tf. compat. v1. ConfigProto( gpu_options= tf. compat. v1. GPUOptions( allow_growth= True ) )
sess = tf. compat. v1. Session( config= config)
下载数据集
( x_train, y_train) , ( x_test, y_test) = cifar10. load_data( )
#keras提供了在线下载数据集的方法,使用load_data( ) 函数就可以在线下载,要保证网络连接正常,下载需要一定的时间,需要耐心的等待
数据集处理
x_train = x_train. astype( 'float32' ) / 255.0
x_test = x_test. astype( 'float32' ) / 255.0
from keras. utils import np_utils # 辅助函数库
y_train = np_utils. to_categorical( y_train)
y_test = np_utils. to_categorical( y_test)
建立网络
model = Sequential( )
model. add( Conv2D( filters= 32 ,
kernel_size= ( 3 , 3 ) ,
input_shape= ( 32 , 32 , 3 ) ,
activation= 'relu' ,
padding= 'same' ) )
model. add( Dropout( 0.25 ) )
model. add( MaxPooling2D( pool_size= ( 2 , 2 ) ) )
model. add( Conv2D( filters= 64 , kernel_size= ( 3 , 3 ) ,
activation= 'relu' , padding= 'same' ) )
model. add( Dropout( 0.25 ) )
model. add( MaxPooling2D( pool_size= ( 2 , 2 ) ) )
model. add( Conv2D( filters= 128 , kernel_size= ( 3 , 3 ) ,
activation= 'relu' , padding= 'same' ) )
model. add( Dropout( 0.25 ) )
model. add( MaxPooling2D( pool_size= ( 2 , 2 ) ) )
model. add( Flatten( ) )
model. add( Dense( 1024 , activation= 'relu' ) )
model. add( Dense( 10 , activation= 'softmax' ) )
model. compile ( loss= 'categorical_crossentropy' ,
optimizer= 'adam' , metrics= [ 'accuracy' ] ) # 使用交叉熵计算损失,使用Adam作优化方法
train_history = model. fit( x_train, y_train,
validation_split= 0.2 ,
epochs= 20 , batch_size= 128 , verbose= 1 )
测试
scores = model. evaluate( x_test, y_test, verbose= 0 )
print ( scores[ 1 ] )
输出
loss: 0.1013 - accuracy: 0.9640 - val_loss: 0.8371 - val_accuracy: 0.7604
0.7523000240325928
虽然只迭代了20次,虽然网络很浅,但是在训练集上的结果还是不错,但是 在验证集和测试集上的效果都不好,结果是高方差的,已经过拟合了。 不要慌,慢慢来,将来我们就会学会如何防止过拟合,如何解决高方差和高偏差的问题,如何改进模型,尝试新的比较流行的网络结构了。