项目概述
把一堆苹果、香蕉的图片交给模型训练,让它判断这张图是苹果还是香蕉
说明: 环境要用到Anaconda 就是要安装下面的包,本博客只写了数据封装和模型搭建的代码。
数据集封装
import os
import cv2
import numpy as np
# 数据集封装
def getFileList(dir, Filelist, ext=None):
"""
来源 csdn 钱彬 (Qian Bin)
获取文件夹及其子文件夹中文件列表
输入 dir:文件夹根目录
输入 ext: 扩展名
返回: 文件路径列表
"""
newDir = dir
if os.path.isfile(dir):
if ext is None:
Filelist.append(dir)
else:
if ext in dir[-3:]:
Filelist.append(dir)
elif os.path.isdir(dir):
for s in os.listdir(dir):
newDir = os.path.join(dir, s)
getFileList(newDir, Filelist, ext)
return Filelist
def get_train_test_data(path):
# 获取电脑里面的训练集样本,封装成数据
data=[[], []]
label=[[], []]
for i in range(0,2):
imglist = getFileList(path[i], [], "jpg")
print('本次执行检索到 ' + str(len(imglist)) + ' 张训练图像')
apple_list = [[], []]
banana_list = [[], []]
for imgpath in imglist:
imgname = os.path.splitext(os.path.basename(imgpath))[0]
img = cv2.imread(imgpath)
# 把彩图转化成灰度图
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = cv2.resize(img,(28,28))
# 添加到训练集
if imgname[:5] == "apple":
apple_list[i].append(img)
else:
banana_list[i].append(img)
# 标签:苹果用0表示 香蕉用1表示
apple_label = [0] * len(apple_list[i])
banana_label = [1] * len(banana_list[i])
data[i] = np.array(apple_list[i] + banana_list[i])
label[i] = np.array(apple_label + banana_label)
return data[0], label[0], data[1], label[1]
if __name__ == '__main__':
# 本次执行检索到 945 张训练图像
# 本次执行检索到 351 张训练图像
# dirlist = ["./fruit_data/Training", "./fruit_data/Test"]
# x_train, y_train, x_test, y_test = get_train_test_data(dirlist)
# print(x_train.shape) # (945, 28, 28)
# print(x_test.shape) # (351, 28, 28)
pass
模型搭建
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
# 数据集 需要自己封装
import get_fruit_data2
# 创建卷积神经网络模型 训练数据集
def load_datasets():
# 数据集返回的是 数据 和 标签
# x_train,y_train = get_fruit_data.get_train_test_data("./fruit_data/Training")
# x_test,y_test = get_fruit_data.get_train_test_data("./fruit_data/Test")
dirlist = ["./fruit_data/Training", "./fruit_data/Test"]
x_train, y_train, x_test, y_test = get_fruit_data2.get_train_test_data(dirlist)
# 图像: 张 行 列 通道 astype 归一化
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float') / 255
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float') / 255
# 对训练用标签进行one-hot编码
y_train = to_categorical(y_train, num_classes=10)
return x_train, x_test, y_train, y_test
def model_create():
# 创建神经网络模型
model = Sequential()
# 添加卷积层: 卷积核个数,卷积核大小,输入维度,激活函数
model.add(Conv2D(filters=32, kernel_size=(5,5), input_shape=(28,28,1), activation='relu'))
# 添加池化层
model.add(MaxPooling2D())
model.add(Conv2D(filters=64, kernel_size=(5,5), activation='relu'))
model.add(MaxPooling2D())
# 一维化
model.add(Flatten())
model.add(Dense(units=1024, activation='relu'))
model.add(Dense(units=512, activation='relu'))
model.add(Dense(units=10, activation='softmax'))
model.compile(optimizer='adam', loss='mse')
return model
def model_train(model, x_train, y_train):
print("Train..................")
model.fit(x_train, y_train, batch_size=20, epochs=2, verbose=1)
def model_test(model, x_test, y_test):
print("Test.................")
y_test = to_categorical(y_test, num_classes=10)
# 验证模型,得到损失值
loss = model.evaluate(x_test, y_test, batch_size=20, verbose=1)
print("test loss: ", loss)
if __name__ == '__main__':
# 加载数据集
x_train, x_test, y_train, y_test = load_datasets()
# 创建卷积神经网络
model = model_create()
# 训练
model_train(model, x_train, y_train)
# 验证
model_test(model, x_test, y_test)
# 预测
y_pre = model.predict_classes(x_test)
print("pre : ", y_pre)
print("true: ", y_test)
# 保存模型 .h5就是模型的后缀名
model.save("./fruit_train_model.h5")
测试模型
import cv2
from keras.models import load_model
# 测试模型
def test_fruit(fruit_path):
# 获取本地图片
img = cv2.imread(fruit_path)
# 把彩图转化成灰度图
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = cv2.resize(img, (28, 28))
# 把图片转化成训练集的shape
img = img.reshape(1, 28, 28, 1)
img = img.astype('float') / 255
# 加载模型
model = load_model("./fruit_train_model.h5")
# 预测
y_pre = model.predict_classes(img)
# print("Pre: ", y_pre)
return y_pre[0]
if __name__ == '__main__':
# 导入一张苹果的图片或者香蕉的图片
# y_pre = test_fruit("./fruit_data/Test/apple_00018.jpg")
# 0表示苹果 1表示香蕉
# print(y_pre)
pass