Keras卷积神经网络识别CIFAR-10图像
CIFAR-10数据集和前面所学的MNIST数据集一样都是图像识别数据集,但MNIST是灰度的手写数字图片,而CIFAR-10是包含10个类别的RGB彩色图片:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。每张图片的尺寸为32×32(3个颜色通道),每个类别有6000张图片,数据集中一共有50000张训练图片和10000张测试图片。
该书的第九章主要讲如何导入CIFAR-10并查看数据集的相关属性。这部分代码在第十章中也会出现,所以这里只放导入数据集的一段,第九章的完整代码见下面的链接:
# Load CIFAR-10 and report the number of training and test images.
# Note from the original: when the dataset is being downloaded for the first
# time, this may need to be run twice.
from keras.datasets import cifar10  # needed here; the snippet used cifar10 without importing it

(x_img_train, y_label_train), (x_img_test, y_label_test) = cifar10.load_data()
print('train:', len(x_img_train))
print('test:', len(x_img_test))
链接:https://pan.baidu.com/s/1Nqj1pf8UV_T0fWKHjVbAsw
提取码:yy7v
第十章用卷积神经网络识别CIFAR-10,实现过程与第八章类似
# Data preprocessing (chapter 9): load CIFAR-10, scale pixels, one-hot encode labels.
from keras.datasets import cifar10
import numpy as np
np.random.seed(10)

(x_img_train, y_label_train), (x_img_test, y_label_test) = cifar10.load_data()
print("train data:", 'images:', x_img_train.shape,
      " labels:", y_label_train.shape)
print("test data:", 'images:', x_img_test.shape,
      " labels:", y_label_test.shape)

# Scale pixel values from the 0-255 range to 0.0-1.0.
x_img_train_normalize = x_img_train.astype('float32') / 255.0
x_img_test_normalize = x_img_test.astype('float32') / 255.0

# One-hot encode the labels. Older Keras exposes to_categorical through
# keras.utils.np_utils; newer releases no longer ship np_utils, so fall back
# to keras.utils.to_categorical there.
try:
    from keras.utils import np_utils
    to_categorical = np_utils.to_categorical
except ImportError:
    from keras.utils import to_categorical
# num_classes=10 is explicit so the encoding width does not depend on which
# labels happen to be present in the data.
y_label_train_OneHot = to_categorical(y_label_train, num_classes=10)
y_label_test_OneHot = to_categorical(y_label_test, num_classes=10)
# A bare `y_label_test_OneHot.shape` only displays in a notebook; print it.
print(y_label_test_OneHot.shape)
# Build the model: two convolution/pooling stages followed by a dense classifier.
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D

model = Sequential()

# Stage 1: 32 3x3 filters; 'same' padding keeps 32x32, pooling halves it to 16x16.
model.add(Conv2D(filters=32, kernel_size=(3, 3),
                 input_shape=(32, 32, 3),
                 activation='relu',
                 padding='same'))
model.add(Dropout(rate=0.25))
model.add(MaxPooling2D(pool_size=(2, 2)))

# Stage 2: 64 3x3 filters; pooling halves 16x16 to 8x8.
model.add(Conv2D(filters=64, kernel_size=(3, 3),
                 activation='relu',
                 padding='same'))
model.add(Dropout(rate=0.25))
model.add(MaxPooling2D(pool_size=(2, 2)))

# Classifier: flatten the feature maps, one hidden dense layer, then a
# 10-way softmax (one output per CIFAR-10 class).
model.add(Flatten())
model.add(Dropout(rate=0.25))
model.add(Dense(1024, activation='relu'))
model.add(Dropout(rate=0.25))
model.add(Dense(10, activation='softmax'))

# summary() prints the table itself and returns None, so wrapping it in
# print() would output an extra "None".
model.summary()
# Training: compile the model, limit GPU memory use, then fit.
model.compile(loss='categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])
import tensorflow as tf
import numpy as np
import keras

# The model is built with standalone `keras`, so the session has to be
# installed on keras.backend. Installing it on tf.keras.backend configures a
# different session that this model never uses.
if hasattr(tf, 'ConfigProto'):
    # TensorFlow 1.x: limit this process to 30% of each GPU's memory.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    keras.backend.set_session(tf.Session(config=config))
else:
    # TensorFlow 2.x removed tf.ConfigProto / tf.Session. The closest
    # equivalent is on-demand memory growth. It can only be enabled before
    # the GPU is initialized; if building the model already initialized it,
    # report that and keep training.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            print('Could not enable GPU memory growth:', err)

# validation_split=0.2 holds out 20% of the training data for validation.
train_history = model.fit(x_img_train_normalize, y_label_train_OneHot,
                          validation_split=0.2,
                          epochs=10, batch_size=300, verbose=2)
# Class index (0-9, the model's output position) -> CIFAR-10 class name.
_CIFAR10_CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
label_dict = {index: name for index, name in enumerate(_CIFAR10_CLASS_NAMES)}
import matplotlib.pyplot as plt
def show_train_history(train_history, train, validation):
    """Plot a training metric and its validation counterpart against epoch.

    train_history: the History object returned by model.fit.
    train: key into train_history.history for the training curve (e.g. 'loss').
    validation: key for the validation curve (e.g. 'val_loss').
    """
    plt.plot(train_history.history[train])
    plt.plot(train_history.history[validation])
    plt.title('train history')
    plt.xlabel('Epoch')
    # Label the y axis with the metric actually plotted, not a fixed word.
    plt.ylabel(train)
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


# Keras records the 'accuracy' metric as 'acc' in older releases and as
# 'accuracy' in newer ones; pick whichever key this History contains.
acc_key = 'acc' if 'acc' in train_history.history else 'accuracy'
show_train_history(train_history, acc_key, 'val_' + acc_key)
show_train_history(train_history, 'loss', 'val_loss')
# Evaluate on the test set. With metrics=['accuracy'] in compile, evaluate
# returns [loss, accuracy].
scores = model.evaluate(x_img_test_normalize, y_label_test_OneHot)
test_accuracy = scores[1]
print('accuracy=', test_accuracy)
# Predict a class index for each test image. Sequential.predict_classes is
# gone from newer Keras. For this model's 10-way softmax it is just the argmax
# over the class axis of predict(), which works on every version.
prediction = np.argmax(model.predict(x_img_test_normalize), axis=-1)
def plot_images_labels_prediction(image, labels, prediction, idx, num=10):
    """Show images in a 5x5 grid, each titled with its label and prediction.

    image: indexable collection of images.
    labels: true labels aligned with image. Entries may be scalars or
        length-1 arrays (cifar10.load_data labels have shape (N, 1)).
    prediction: predicted label per image. Pass an empty sequence or None to
        show only the true labels.
    idx: index of the first image to show.
    num: how many images to show. Default 10; capped at 25 (the grid size)
        and at the number of images remaining after idx.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    # Never exceed the 5x5 grid or run past the end of the data.
    num = min(num, 25, len(image) - idx)
    has_prediction = prediction is not None and len(prediction) > 0
    for i in range(num):
        ax = plt.subplot(5, 5, 1 + i)
        # cmap only affects single-channel images; matplotlib ignores it for RGB.
        ax.imshow(image[idx], cmap='binary')
        # ravel()[0] turns a length-1 label such as [3] into 3; scalars pass through.
        title = "label=" + str(np.ravel(labels[idx])[0])
        if has_prediction:
            title += ",prediction=" + str(prediction[idx])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx += 1
    plt.show()


plot_images_labels_prediction(x_img_test, y_label_test, prediction, 0, 10)
# Per-class probabilities for every test image (the model ends in Dense(10, softmax)).
Predicted_Probability = model.predict(x_img_test_normalize)


def show_Predicted_Probability(y, prediction,
                               x_img, Predicted_Probability, i):
    """Print the true and predicted class of image i, show it, and list its probabilities.

    y: true labels, indexed as y[i][0] (shape (N, 1) as from cifar10.load_data).
    prediction: predicted class index per image.
    x_img: images; x_img[i] is the one displayed.
    Predicted_Probability: per-image rows of class probabilities, indexed [i][j].
    i: index of the image to inspect.
    """
    print('label:', label_dict[y[i][0]],
          'predict:', label_dict[prediction[i]])
    plt.figure(figsize=(2, 2))
    # Display the image passed in, not a global; it is shown as-is, with no
    # reshape, so images of any size work.
    plt.imshow(x_img[i])
    plt.show()
    for j in range(10):
        print(label_dict[j] + ' Probability:%1.9f' % (Predicted_Probability[i][j]))


show_Predicted_Probability(y_label_test, prediction,
                           x_img_test, Predicted_Probability, 0)
# Confusion matrix: rows are true labels, columns are predicted labels.
import pandas as pd
print(label_dict)
# y_label_test has shape (N, 1); flatten it so crosstab receives a 1-D array.
confusion_matrix = pd.crosstab(y_label_test.reshape(-1), prediction,
                               rownames=['label'], colnames=['predict'])
# A bare expression only displays in a notebook; print it so a script shows it.
print(confusion_matrix)
完整代码见下面的链接:
链接:https://pan.baidu.com/s/199EgdLhxRgDMHXfvLTLG6A
提取码:jj4f