首先,让我们来导入必要的库和模块:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, DepthwiseConv2D, BatchNormalization, ReLU
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report
创建一个输入层来设置图像形状:
inputs = Input(shape=(224,224,3))
接下来,我们可以通过实例化一个MobileNetV2模型来创建基础网络:
base = tf.keras.applications.MobileNetV2(input_tensor=inputs, include_top=False, weights="imagenet")
在这个基本的MobileNetV2模型之上,我们可以添加我们自己的分类层,该层包括一个全局平均池化层,一个密集层和dropout层:
x = base.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024,activation='relu')(x)
x = Dropout(0.25)(x)
我们的分类任务是对10种不同的水果和蔬菜进行分类,因此我们需要一个带有10个类别的密集层作为输出层:
predictions = Dense(10, activation='softmax')(x)
最后我们要创建模型并指定优化器和损失函数:
model = Model(inputs=base.input, outputs=predictions)
# 使用Adam优化器,设置学习率和损失函数
model.compile(optimizer = Adam(lr=0.0001),
loss = 'categorical_crossentropy',
metrics = ['accuracy'])
在准备训练网络之前,我们还需要进行数据增强和准备图像生成器。这里我们选择使用ImageDataGenerator模块,同时进行一些数据增强操作,如随机旋转、剪切和缩放等。
# 设定图片尺寸和batch size
img_size = (224,224)
batch_size = 32
# 设置ImageDataGenerator,进行数据增强
train_data_generator = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
validation_split=0.15)
train_datagen = train_data_generator.flow_from_directory('./fruits-and-vegetables/train',
target_size=img_size,
shuffle=True,
batch_size=batch_size,
class_mode='categorical',
subset='training')
val_datagen = train_data_generator.flow_from_directory('./fruits-and-vegetables/train',
target_size=img_size,
shuffle=True,
batch_size=batch_size,
class_mode='categorical',
subset='validation')
然后就可以开始训练我们的模型啦:
# 训练模型并保存权重
epochs = 20
history = model.fit(train_datagen,
steps_per_epoch = train_datagen.samples // batch_size,
validation_data = val_datagen,
validation_steps = val_datagen.samples // batch_size,
epochs = epochs)
model.save_weights('fruits_and_vegetables_classification.h5')
最后,我们可以对测试集进行预测,并打印出分类性能报告:
# 加载测试数据
test_data_generator = ImageDataGenerator(rescale = 1./255)
test_datagen = test_data_generator.flow_from_directory('./fruits-and-vegetables/test',
target_size=img_size,
shuffle=False,
batch_size=batch_size,
class_mode='categorical')
# 预测测试数据
pred = model.predict(test_datagen)
# 计算分类性能
test_datagen.reset()
predIdxs = np.argmax(pred, axis=1)
print(classification_report(test_datagen.classes, predIdxs, target_names=test_datagen.class_indices.keys()))