python实现斑马线的识别
使用tensorflow.keras.preprocessing.image实现斑马线识别
导入库
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Flatten,Dense
import tensorflow as tf
from tensorflow.keras.models import Sequential,load_model
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import os
训练集和验证集的导入（验证集使用test目录）
# Hide TensorFlow's C++ INFO/WARNING log lines.
# NOTE(review): this runs after `import tensorflow` above, so messages emitted
# during the import itself may not be suppressed; consider moving it before
# the TensorFlow imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Use a font that can render the Chinese labels in the plots below, and keep
# the minus sign displayable with that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Dataset roots (Windows paths). The "test" folder is used as the validation
# set during training.
train_dir = 'F:\\study\\data\\Zebra\\train\\'
validation_dir = 'F:\\study\\data\\Zebra\\test\\'

# Alternative built with os.path.join:
# path = r'F:\study\data\Zebra'
# train_dir = os.path.join(path, 'train')
# validation_dir = os.path.join(path, 'test')
将图片转换为模型输入格式（缩放到150×150，像素值归一化到0~1）
# Training images: rescale pixels from [0, 255] to [0, 1], resize to 150x150,
# yield batches of 20 with binary labels, shuffled.
train_datagen = ImageDataGenerator(rescale=1 / 255)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
    shuffle=True,
)

# Validation images: same preprocessing (shuffle left at its default).
validation_datagen = ImageDataGenerator(rescale=1 / 255)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)
从训练数据中输出几张照片
# Grab one batch of training images (labels are discarded). The batch gets its
# own name so the ImageDataGenerator `train_datagen` is not overwritten by an
# image array.
sample_images, _ = next(train_generator)


def plotImages(images_arr, n=5):
    """Show up to ``n`` images from ``images_arr`` side by side in one row.

    Args:
        images_arr: a sequence of images accepted by ``Axes.imshow``. The
            generators in this script rescale pixels by 1/255, so float
            values in [0, 1] are expected here.
        n: number of panels in the row (default 5). If ``images_arr`` holds
            fewer images, the remaining panels are left blank.
    """
    # squeeze=False keeps `axes` a 2-D array even when n == 1, so `.flat`
    # always works; ~4 inches per panel instead of one 20x20-inch figure.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    # Hide frames on every panel, including ones with no image.
    for ax in axes.flat:
        ax.axis('off')
    for img, ax in zip(images_arr, axes.flat):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()


plotImages(sample_images[:5])
构建CNN模型
# CNN: four Conv2D + MaxPooling2D stages on 150x150 RGB input, then a dense
# head. The single sigmoid unit outputs the probability that the binary label
# is 1, matching class_mode='binary' and binary_crossentropy below.
model = Sequential([
    Conv2D(128, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(1, activation='sigmoid'),
])
model.summary()

# `lr` is a deprecated alias. Newer tf.keras optimizers only warn and fall back
# to their default learning rate (so 1e-4 would be silently ignored), and
# Keras 3 rejects it outright. Pass `learning_rate` explicitly.
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.RMSprop(learning_rate=1e-4),
    metrics=['acc'],
)
# Sanity check: pull a single training batch and print its shape, the first
# image's pixel values, and the batch labels. `data_batch` and `labels_batch`
# remain bound at module level afterwards.
data_batch, labels_batch = next(train_generator)
print('data_shape:', data_batch.shape)
print('单张图像:\n', data_batch[0])
print('Batch:', labels_batch.shape, labels_batch)
模型训练
# Train for 20 epochs. `Model.fit` accepts the directory iterators directly,
# whereas `fit_generator` is deprecated (removed in Keras 3).
# Steps per epoch and validation steps are inferred from the generators'
# lengths, so the whole validation set is scored once per epoch. The old
# hard-coded validation_steps=50 (x 20 images) only fit a 1000-image set:
# smaller sets ran out of data, and larger ones were only partly scored.
history = model.fit(
    train_generator,
    epochs=20,
    validation_data=validation_generator,
)
# model.save(r'F:\study\data\Zebra\zebra_crossing.h5')

# Print the per-epoch metrics dict rather than the History object's repr.
print(history.history)
这里训练了20轮（epochs=20）
查看acc和loss
# Per-epoch metrics recorded by model.fit ('acc' was requested at compile time).
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
# 1-based epoch numbers for the x axis.
epochs = range(1, len(acc) + 1)

# Figure 1: accuracy, training vs validation. A single colour argument is used
# per line; the old 'b' fmt plus color='red' conflicted and made matplotlib
# warn. Accuracy and loss get separate figures because their scales differ.
plt.figure(num=1)  # accuracy
plt.plot(epochs, acc, color='blue', label='训练集acc')
plt.plot(epochs, val_acc, color='blue', linestyle='--', label='验证集acc')
plt.legend()

# Figure 2: loss, training vs validation.
plt.figure(num=2)  # loss
plt.plot(epochs, loss, color='red', label='训练集loss')
plt.plot(epochs, val_loss, color='red', linestyle='--', label='验证集loss')
plt.legend()
#plt.savefig('loss_plain.png')
plt.show()
输出预测错误的样本和图片
# Display names indexed by integer class label.
# NOTE(review): flow_from_directory derives class indices from the
# sub-directory names; confirm via `validation_generator.class_indices` that
# index 0 really is the zebra-crossing class.
counname = ["zebra crossing", "others"]


def error_found(data_x, data_y, images=None):
    """Print every misclassified sample and return the last misclassified image.

    Args:
        data_x: ground-truth integer labels, one per sample.
        data_y: predicted integer labels, aligned index-by-index with data_x.
        images: images aligned with the labels. When omitted, the
            module-level ``data_batch`` is used (the original behavior).

    Returns:
        The image of the last sample whose prediction differs from its true
        label, or the int ``0`` when every prediction is correct.

    Raises:
        IndexError: if ``data_y`` is shorter than ``data_x``.
    """
    error = 0
    for i, true_label in enumerate(data_x):
        predicted = data_y[i]
        if true_label != predicted:
            # Look up the image source lazily so the global fallback is only
            # touched when a mismatch actually occurs.
            error = (data_batch if images is None else images)[i]
            # data_x holds the truth and data_y the prediction; print each under
            # its matching label.
            print("真实值" + counname[true_label] + "\n预测值" + counname[predicted])
    return error
# Take one validation batch. `data_batch` is deliberately rebound at module
# level: error_found(data_x, data_y) looks the images up under that name.
data_batch, labels_batch = next(validation_generator)
print('data_shape:', data_batch.shape)

# Ground-truth labels as plain Python ints.
data_x = labels_batch.astype('int').tolist()
# Sequential.predict_classes was removed in TF 2.6. For a single sigmoid
# output it thresholded the probability at 0.5 and cast to int32, which is
# reproduced here.
data_y = (model.predict(data_batch) > 0.5).astype('int32').flatten().tolist()
print(data_x)
print(data_y)

error = error_found(data_x, data_y)
# error_found returns the int 0 when nothing was misclassified, otherwise an
# image array. Check that explicitly instead of relying on imshow(0) raising
# inside a broad `except Exception`, which also hid genuine plotting errors.
if isinstance(error, int):
    print("没有找到预测错误图片,太完美了")
else:
    plt.title('预测错误图片')
    plt.imshow(error)
    plt.show()