1. 首先是数据集的下载和使用
下载地址: http://www.cs.toronto.edu/~kriz/cifar.html
下载完成后无需解压
直接调用语句即可读取数据集
#具体语句如下
cifar10 = tf.keras.datasets.cifar10 #使用内置API keras下载数据集(速度缓慢)
(x_train, y_train), (x_test, y_test) = cifar10.load_data() # 直接读取数据即可
2.构建CNN模型
CIFA_10数据集是由 60000
张RGB彩色图片构成
一共有10类分别为["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]
其中50000
张图片为训练集,10000
张为测试集图片
在读取数据后我们可以输出测试集和训练集的数组大小
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
运行结果为
在送入卷积神经网络前先进行处理,使图片中的每个像素值都处于0~1之间
并转化为tf张量
x_train, x_test = tf.cast(x_train, dtype=tf.float32)/255.0, tf.cast(x_test, dtype=tf.float32)/255.0
y_train, y_test = tf.cast(y_train, dtype=tf.int32), tf.cast(y_test, dtype=tf.int32)
构建训练模型
# 构建Sequential模型
#建立模型 model
model = Sequential([
#特征提取层1
layers.Conv2D(16, kernel_size=(3,3), padding="same", activation=tf.nn.relu,input_shape=x_train.shape[1:]),
layers.Conv2D(16, kernel_size=(3,3), padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=(2,2)),
layers.Dropout(0.2),
#特征提取层2
layers.Conv2D(32, kernel_size=(3,3), padding="same", activation=tf.nn.relu),
layers.Conv2D(32, kernel_size=(3,3), padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=(2,2)),
layers.Dropout(0.2),
#全连接层
layers.Flatten(),
layers.Dropout(0.2),
layers.Dense(128,activation='relu'),
layers.Dropout(0.2),
layers.Dense(10,activation="softmax"),
])
结构如图所示
4. 配置训练方法
使用 model.complie
来配置训练的方法
# 配置方法
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
5.训练模型
history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)
结果如图所示:
可见准确率达到了 65%
6.评估模型
对测试集使用 evaluate函数进行模型的评估 可以得到该模型参数对测试集的图片准确率
model.evaluate(x_test, y_test, verbose=2)
准确率为 68%
7.测试
从测试集中随机抽取几张图片进行识别
plt.figure(figsize=(10,10))
for i in range(4):
num = np.random.randint(1,10000)
plt.subplot(1,4,i+1)
plt.axis('off')
plt.imshow(x_test[num],cmap='gray')
demo = tf.reshape(x_test[num],(1,32,32,3))
y_pred = name[np.argmax(model.predict(demo))]
plt.title('Original: ' + name[(y_test.numpy())[num,0]] + '\nPredict: ' + y_pred)
plt.show()
可以看到此次的识别 对了两个错了两个