第一步:准备数据
15种水果数据:'Apple_Braeburn', 'Banana', 'Blueberry', 'Cherry', 'Chestnut', 'Cocos', 'Corn', 'Eggplant','Fig', 'Ginger_Root', 'Granadilla', 'Lemon', 'Onion_Red', 'Orange', 'Pear'
,总共有7196张图片,每个文件夹单独放一种数据
第二步:搭建模型
本文选择一个简单cnn网络,其网络结构如下:
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 96, 96, 3) 0
_________________________________________________________________
conv0 (Conv2D) (None, 32, 32, 32) 896
_________________________________________________________________
dropout_1 (Dropout) (None, 32, 32, 32) 0
_________________________________________________________________
bn0 (BatchNormalization) (None, 32, 32, 32) 128
_________________________________________________________________
activation_1 (Activation) (None, 32, 32, 32) 0
_________________________________________________________________
max_pool_1 (MaxPooling2D) (None, 16, 16, 32) 0
_________________________________________________________________
conv1 (Conv2D) (None, 5, 5, 4) 1156
_________________________________________________________________
dropout_2 (Dropout) (None, 5, 5, 4) 0
_________________________________________________________________
bn1 (BatchNormalization) (None, 5, 5, 4) 16
_________________________________________________________________
activation_2 (Activation) (None, 5, 5, 4) 0
_________________________________________________________________
max_pool_2 (MaxPooling2D) (None, 2, 2, 4) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 16) 0
_________________________________________________________________
dense (Dense) (None, 15) 255
=================================================================
Total params: 2,451
Trainable params: 2,379
Non-trainable params: 72
第三步:训练代码
1)损失函数为:交叉熵损失函数
2)训练代码:
model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size, IMG_H, IMG_W),
steps_per_epoch=max(1, num_train // batch_size),
validation_data=generate_arrays_from_file(lines[num_train:], batch_size, IMG_H, IMG_W, flag=False),
validation_steps=max(1, num_val // batch_size),
epochs=100,
initial_epoch=0,
class_weight='auto',
callbacks=[checkpoint_period1, reduce_lr, early_stopping, csv_logger])
model.save_weights(log_dir + 'last1.h5')
第四步:统计正确率
HappyModel_model_logep042-accuracy0.966-val_accuracy0.999.h5正确率高达99.9%
第五步:搭建GUI界面
第六步:整个工程的内容
有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码,主要使用方法可以参考里面的“文档说明_必看.docx”
代码的下载路径(新窗口打开链接):基于keras框架的CNN神经网络水果种类识别分类系统源码
有问题可以私信或者留言,有问必答