keras简单神经网络搭建并训练测试

2 篇文章 0 订阅

在这里插入图片描述
原图和结果图

通过Keras搭建简单的神经网络,这里以minist数据集为例,测试手写字体训练效果,并进行一些简单的应用。

环境

在Windows下进行的测试,主要的安装包如下:

  • tensorflow_gpu==2.2.0
  • imutils==0.5.4
  • opencv_python==4.5.3.56
  • scikit_image==0.18.3
  • scikit_learn==0.24.2
  • numpy==1.21.2
  • py_sudoku==1.0.1

目录结构如下:
在这里插入图片描述

搭建网络

通过Keras来搭建几层简单网络,可以用TensorFlow里集成的Keras,或者单独安装Keras包。使用 MNIST 数据集来训练模型识别数字。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout

class MnistNet:
	@staticmethod
	def build(width, height, depth, classes):
		# 初始化模型
		model = Sequential()
		inputShape = (height, width, depth)

		# 从 CONV 到 RELU 到 POOL layers
		model.add(Conv2D(32, (5, 5), padding="same",
			input_shape=inputShape))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2)))

		# 再次从 CONV 到 RELU 到 POOL layers
		model.add(Conv2D(32, (3, 3), padding="same"))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2)))

		# FC层到relu层
		model.add(Flatten())
		model.add(Dense(64))
		model.add(Activation("relu"))
		model.add(Dropout(0.5))

		# 再次FC层到relu层
		model.add(Dense(64))
		model.add(Activation("relu"))
		model.add(Dropout(0.5))

		# 用softmax函数分类
		model.add(Dense(classes))
		model.add(Activation("softmax"))

		# 返回模型
		return model

训练网络


from mnistnet import MnistNet
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
import argparse

# 构造参数
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
	help="model output path")
args = vars(ap.parse_args())

# 设定学习率,迭代次数,送入网络批次大小
INIT_LR = 1e-3
EPOCHS = 16
Batch_Size = 160

# 获取MNIST dataset
print("[LOGS] Please wait...")
((trainData, trainLabels), (testData, testLabels)) = mnist.load_data()

# 训练数据和测试数据设定维度
print(trainData.shape[0])
trainData = trainData.reshape((trainData.shape[0], 28, 28, 1))
testData = testData.reshape((testData.shape[0], 28, 28, 1))

# 归一化0-1之间
trainData = trainData.astype("float32") / 255.0
testData = testData.astype("float32") / 255.0

# 标签转为向量
le = LabelBinarizer()
trainLabels = le.fit_transform(trainLabels)
testLabels = le.transform(testLabels)

# 初始化模型,如果只识别两类则loss = “ binary_crossentropy”
opt = Adam(lr=INIT_LR)
model = MnistNet.build(width=28, height=28, depth=1, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])
print("[LOGS] compiling model...")

# 训练
H = model.fit(
	trainData, trainLabels,
	validation_data=(testData, testLabels),
	batch_size=Batch_Size,
	epochs=EPOCHS,
	verbose=1)
print("[LOGS] training network...")

# 评估网络模型
predictions = model.predict(testData)
print("[LOGS] evaluating network...")
print(classification_report(
	testLabels.argmax(axis=1),
	predictions.argmax(axis=1),
	target_names=[str(x) for x in le.classes_]))

# 保存模型
model.save(args["model"], save_format="h5")

使用命令行输入来启动训练:

python train_classifier.py --model model/model_mnist.h5

在这里插入图片描述
等待训练完成,如下图示意:
在这里插入图片描述

测试效果

通过手写一些数字0-9来进行简单的测试。

from tensorflow.keras.models import load_model
import cv2
import imutils
from imutils.contours import sort_contours
import numpy as np
#获取图像
imgPath = "image/test3.jpg"
model_path = "model/model_mnist.h5"
is_show = True
# 获取视频
vs_img = cv2.imread(imgPath)
# 加载模型
model = load_model(model_path)
model.summary()

# 调整大小
frame = imutils.resize(vs_img,width=200)
# 调试过程中可以显示一下
# if Debug:
#     cv2.imshow("frame",frame)
#     cv2.waitKey(10)
# 转为灰度图
gray = cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)
# 滤波
bl = cv2.GaussianBlur(gray,(5,5),0)

# 边缘检测找轮廓
edge_canny = cv2.Canny(bl, 85, 200)

# 膨胀处理
kernel = np.ones((3,3),np.uint8)
edge_canny = cv2.dilate(edge_canny,kernel)


if is_show:
    cv2.imshow("edge_canny", edge_canny)
    cv2.waitKey(10)
items = cv2.findContours(edge_canny.copy(), cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
# 返回conts中的countors(轮廓)
conts = items[0] if len(items) == 2 else items[1]
# print(conts)
# if is_show:
#     cv2.drawContours(frame, conts, -1, (0, 255, 0), 1)
#
#     cv2.imshow("src",frame)
#     cv2.waitKey(10)
# 从左到右排序
conts,_ = sort_contours(conts,method="left-to-right")
# print(conts)
# 初始化列表放找到的字符
find_chars = []

#遍历找字符
for i in conts:
    #print(np.array(i))
    (x,y,w,h) = cv2.boundingRect(i)
    # 过滤一下,找出字符边框
    if(w>2 and w< 100) and (h>5 and h< 100):
        # 框字符
        roi = gray[y:y+h,x:x+w]
        mask = np.zeros(roi.shape,dtype="uint8")
        digit = cv2.bitwise_and(roi, roi, mask=mask)
        # 自动阈值处理
        _, th = cv2.threshold(roi, 0 ,255,cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
        # 宽高
        th_H,th_W = th.shape
        # 缩放到28尺寸
        if th_H < th_W:
            th = imutils.resize(th,width=28)
        else:
            th = imutils.resize(th,height=28)
        # if is_show:
        #     cv2.imshow("th", th)
        #     cv2.waitKey(10)
        # 缩放后的宽高
        th_H, th_W = th.shape
        dx = int(max(0,28-th_W)/2)
        dy = int(max(0,28-th_H)/2)

        # 填充到28x28
        padding = cv2.copyMakeBorder(th,top=dy,bottom=dy,left=dx,right=dx,
                                     borderType=cv2.BORDER_CONSTANT,value=(0,0,0))
        padding = cv2.resize(padding,(28,28))

        # 缩放到0-1,扩展维度
        padding = padding.astype("float32")/255.0
        padding = np.expand_dims(padding,axis=-1)

        #存入列表
        print(((x,y,w,h)))
        find_chars.append((padding,(x,y,w,h)))
    else:
        print("next ... ")
        continue

# 提取
boxes = [b[1] for b in find_chars]

find_chars = np.array([f[0] for f in find_chars], dtype="float32")
if find_chars is None:
    print("can not find chars ...")

# 放入模型
predicts = model.predict(find_chars)

# 标签
labels = "0123456789"
# 预测显示
for (pred, (x,y,w,h)) in zip(predicts,boxes):
    # 返回最大值
    p = np.argmax(pred)
    pre = pred[p]
    label = labels[p]
    # 绘制框显示
    cv2.rectangle(frame,(x,y),(x+w,y+h),(255,0,0),2)
    cv2.putText(frame,label,(x-10,y-10),cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
    cv2.imshow("result",frame)
    cv2.waitKey(10)

测试效果如下所示:
在这里插入图片描述

简单益智游戏应用

拿上面训练好的数字模型来识别数独板中的数字并解决数独填空。
流程如下:

  1. 输入一张待解谜的数独图像;
  2. 在图像中找到每个数字的位置;
  3. 给数独划分网格,一般是9x9,计算得到每个格子的位置;
  4. 判断格子中是否有数字,有的话就进行OCR识别;
  5. 用数独算法来解谜题;
  6. 结果输出显示

识别主要代码如下:


ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
	help="path to trained digit classifier")
ap.add_argument("-i", "--image", required=True,
	help="path to input sudoku puzzle image")
ap.add_argument("-d", "--is_show", type=int, default=-1,
	help="is show each step ")
args = vars(ap.parse_args())

# 加载模型
model = load_model(args["model"])
print("loading digit classifier...")
# 获取图像
image = cv2.imread(args["image"])
print("processing image...")
if image is None:
	print("could not load image ...")
# 调整大小
image = imutils.resize(image, width=400)
src = image.copy()
gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
# 9x9 数独格板
board_9 = np.zeros((9, 9), dtype="int")

# 宽度和高度方向上单个小方格尺寸
stepX = gray.shape[1] // 9
stepY = gray.shape[0] // 9

# 存每个小格子位置
each_loc = []

# 获取格子位置
for y in range(0, 9):
	# 当前格子位置
	c_row = []

	for x in range(0, 9):
		# 当前格子坐标
		startX = x * stepX
		startY = y * stepY
		endX = (x + 1) * stepX
		endY = (y + 1) * stepY
		# 存下来
		c_row.append((startX, startY, endX, endY))

		# 拿到小格子,并提取数字
		grid_img = gray[startY:endY, startX:endX]
		number = extract_number(grid_img, is_show=False)

		# 判断一下
		if number is not None:
			two_h = np.hstack([grid_img, number])
			# cv2.imshow("grid_img/number", two_h)

			# 将格子图缩放到28x28
			roi = cv2.resize(number, (28, 28))
			roi = roi.astype("float") / 255.0
			roi = img_to_array(roi)
			roi = np.expand_dims(roi, axis=0)

			# 预测格子里的数字
			pred = model.predict(roi).argmax(axis=1)[0]
			board_9[y, x] = pred

	# 放入列表
	each_loc.append(c_row)

# 数独板并显示
print("OCR sudoku board:")
makeup = Sudoku(3, 3, board=board_9.tolist())
makeup.show()

# 计算填写
print("solving sudoku makeup...")
solution = makeup.solve()
solution.show_full()

# 遍历每个格子
for (grid, b) in zip(each_loc, solution.board):

	for (box, n) in zip(grid, b):
		# 坐标位置
		startX, startY, endX, endY = box

		# 显示信息
		textX = int((endX - startX) * 0.3)
		textY = int((endY - startY) * -0.25)
		textX += startX
		textY += endY

		cv2.putText(src, str(n), (textX, textY),
			cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

cv2.imshow("Results", src)
cv2.waitKey(0)
cv2.imwrite("output/res.jpg",src)

测试结果:
绿色为识别后解出来数字。
在这里插入图片描述

代码

完整代码:
https://github.com/ssggle/keras_mnistnet

Reference

https://keras.io/examples/
http://yann.lecun.com/exdb/mnist/

  • 4
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

圆滚熊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值