Keras实现CIFAR-10图片分类

声明:该博客所有博文均为个人学习过程问题记录,仅供学术交流,非商用。如部分不小心侵犯了其版权,还望海涵,请及时联系本人删除或修改。

一、学习视频

在看完曹健老师用五个经典网络实现CIFAR-10图片分类后,无意中又看到了这个UP主的教学视频,讲得很仔细,按照其教学再次进行了实践,链接如下:https://www.bilibili.com/video/BV1E7411q7wW?t=410

二、学习代码

import os
import shutil

from keras import Input, Model
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from keras.optimizers import RMSprop
from keras.utils import to_categorical, plot_model
from matplotlib import pyplot as plt

epochs = 5
batch_size = 32 #批大小
opt = RMSprop(lr=0.0001,decay=1e-6) #使用RMSprop优化器
num_classes =10  #有多少种类别
input_shape = (32, 32, 3) #图片的shape
output_dir = './output' #输出目录

if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
    print('1')
os.makedirs(output_dir)
print('2')

#keras提供的在线下载数据集的方法load_data()
(x_train, y_train),(x_val, y_val) = cifar10.load_data()

# 归一化,将像素值转化到0-1之间
x_train = x_train.astype('float32')/255.0
x_val = x_val.astype('float32')/255.0

#将类别向量转换为二进制(只有0和1)的矩阵类型表示,独热编码
y_train = to_categorical(y_train, num_classes)
y_val = to_categorical(y_val, num_classes)

#创建模型
input = Input(shape=input_shape)
x = Conv2D(filters=32, kernel_size=(3, 3),activation='relu', padding='same')(input)
x = Conv2D(filters=32, kernel_size=(3, 3),activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(rate=0.25)(x) #防止过拟合,提高模型泛化能力

x = Conv2D(filters=64, kernel_size=(3, 3),activation='relu', padding='same')(x)
x = Conv2D(filters=64, kernel_size=(3, 3),activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(rate=0.25)(x)

x = Flatten()(x)
x = Dense(units=512, activation='relu')(x)
x = Dense(units=num_classes, activation='softmax')(x)

model = Model(inputs = input, outputs = x, name='Cifar_Model') #创建模型
model.summary() #将模型参数信息打印到控制台
model_img = output_dir + 'cifar10_cnn.png' #模型结构图保存路径
plot_model(model, to_file=model_img, show_shapes=True) #将模型结构保存为一张图片
print('3')

#编译模型
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy']) #使用交叉熵计算损失

#训练模型
model.fit(x_train, y_train,  #为模型提供训练数据
          epochs=epochs,
          batch_size=batch_size,
          validation_data=(x_val, y_val),
          shuffle=True)  #混洗数据

#保存模型
model_path = output_dir + '/keras_cifar10_trained_model.h5'
model.save(model_path)
print('4')

#验证验证集里的图片
name = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
n = 20 #取多少张图片
x_val=x_val[:n]
y_val=y_val[:n]

#预测
y_predict = model.predict(x_val,batch_size=n)

#绘制预测结果
plt.figure(figsize=(18,3))  #指定画布大小
for i in range(n):
    plt.subplot(2,10,i+1)  #2行10列
    plt.axis('off')  #取消x,y轴坐标
    plt.imshow(x_val[i])  #显示图片
    if y_val[i].argmax() == y_predict[i].argmax():
        #预测正确,用绿色标题
        plt.title('%s,%s' %(name[y_val[i].argmax()], name[y_predict[i].argmax()]), color='green')
    else:
        # 预测错误,用红色标题
        plt.title('%s,%s' % (name[y_val[i].argmax()], name[y_predict[i].argmax()]), color='red')
plt.show() #显示画布

predict_img = output_dir + '/predict.png'
plt.savefig(predict_img)  #保存预测图片

三、运行结果

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值