【OpenCV】手写算式识别

OpenCV 机器学习库提供了一系列 SVM 函数和类来实现 SVM 模型的训练和预测,方便用户实现自己的 SVM 模型,并应用于分类问题。本文主要介绍使用 openCV 实现手写算式识别的工作原理与实现过程。

目录

1 SVM 模型

1.1 SVM 模型介绍

1.2 SVM 模型原理

2 手写算式识别

2.1 字符识别

2.2 算式识别


1 SVM 模型

1.1 SVM 模型介绍

        SVM 是支持向量机(Support Vector Machine)的英文缩写,是统计学习理论中一种重要的分类方法,其早期工作来自前苏联 Vladimir N. Vapnik 和 Alexander Y. Lerner 在1963年发表的研究。

        1995年,Corinna Cortes 和 Vapnik 提出了软边距的非线性 SVM 并将其应用于手写字符识别问题,为 SVM 在其他领域的应用提供了参考。

SVM 的优点主要包括:

​    1)具有较好的可解释性。SVM 的决策函数和支持向量清晰,易于理解。

​    2)适用性广泛。SVM 能够应用于多种数据类型和领域,如文本分类、图像识别和生物信息学等。

​    3)鲁棒性强。SVM 对训练数据中的噪声和异常点具有较强的容错能力,能有效处理输入数据中的噪声。

​    4)适合高维数据。通过核函数,SVM 能够将低维空间的非线性问题映射到高维空间,进行线性划分,从而解决复杂的非线性问题。

​    5)可控制的过拟合。通过调整正则化参数和松弛变量,SVM 可以控制模型的复杂度,有效避免过拟合问题。

​    6)避免陷入局部最优解。使用结构风险最小化原则,使得 SVM 能够更好地避免陷入局部最优解,并具有较低的泛化误差。

1.2 SVM 模型原理

        在二分类问题中,给定输入数据和学习目标: X=\{ X_1, X_2, ... , X_N\}y \in \{-1, 1\},若存在决策边界(decision boundary)

\omega ^T X + b = 0

将样本按类别分开,则称该分类问题是线性可分的(Linear Separable)。

        按照统计学习理论,分类器在经过学习新数据时会产生风险,风险的类型分为经验风险和结构风险:

式中 f 表示分类器,经验风险由损失函数定义,描述了分类器所给出的分类结果的准确程度;结构风险由分类器参数矩阵的范数定义,描述了分类器自身的复杂程度以及稳定程度。

        复杂的分类器容易过拟合,因此是不稳定的。通过最小化经验风险和结构风险的线性组合以确定其模型参数:

式中 C 是正则化参数,当 p = 2 时,该式被称为 L_2 正则化。

​    对于线性可分问题,SVM 经验风险为 0,SVM 模型简化为最小化结构风险,由于点到超平面的距离反比于 || ω ||,因此模型可解释为最大化样本到超平面的最小距离,

即最优超平面距离给定的每个样本尽可能远。

2 手写算式识别

2.1 字符识别

        OpenCV 机器学习库提供了一系列 SVM 函数和类来实现 SVM 模型的训练和预测,可以很方便地实现用户自定义的分类模型。

使用 OpenCV 实现 SVM 模型的基本步骤如下:

    (1)创建模型。使用 cv2.ml.SVM_create() 创建 SVM 模型,使用 setKernel() 指定核函数;

    (2)初始化模型参数。使用 setC() 和 setGamma() 设置参数的初始值;

    (3)模型训练。使用 train() 函数,以及向量化的样本和分类标签,训练模型;

    (4)模型评估。使用 predict() 预测新样本,并统计正确率;

    (5)模型保存。使用 save() 保存模型,文件格式为 *.dat 。

        在手写算式的字符识别中,需要识别数字 0 ~ 9,以及 +,-,×,÷,(,)和 = 共 17 种字符。SVM 模型的输入样本是字符图像向量化的结果,处理步骤包括:

      1)图像缩放。将字符图像统一成 28 × 28 大小;

      2)颜色反转。使用 cv2.bitwise_not() 函数实现颜色反转,便于后续步骤;

      3)去偏斜。使用 cv2.moments() 计算图像的矩,然后使用 cv2.warpAffine() 去偏斜;

      4)向量化。将图像按照十字划分成 4 个区域,计算每个区域的方向梯度直方图,拼接成一个向量。

参考链接:OpenCV: OCR of Hand-written Data using SVM

2.2 算式识别

        手写算式识别包括 3 个阶段:字符分割、图像预处理和字符识别。字符分割用于提取输入图像中的连续字符,图像预处理用于字符图像的特征化,字符识别用于图像与字符的对应。最后按照顺序拼接识别到的字符,就得到输出表达式。

#-*- Coding: utf-8 -*-

import cv2
import numpy as np
import gradio as gr

# 加载模型
model = cv2.ml.SVM_load('./svm_data.dat')
chars = '0123456789+-*/()='


SZ = 28
bin_n = 16 # Number of bins


def resize(src_img, size):
   """Resize source image to given size"""
   # 获取原图像的宽、高
   h, w = src_img.shape

   if h >= size and w >= size:
      # 图像缩放
      dst_img = cv2.resize(src_img, (size, size), interpolation=cv2.INTER_CUBIC)
   elif h >= size:
      # 填充左右边缘
      dst_img = np.zeros(shape=(h, size), dtype=np.uint8)
      dst_img[:, (size-w)//2:(size-w)//2+w] = src_img
      dst_img = cv2.resize(dst_img, (size, size), interpolation=cv2.INTER_CUBIC)
   elif w >= size:
      # 填充上下边缘
      dst_img = np.zeros(shape=(size, w), dtype=np.uint8)
      dst_img[(size-h)//2:(size-h)//2+h, :] = src_img
      dst_img = cv2.resize(dst_img, (size, size), interpolation=cv2.INTER_CUBIC)
   else:
      # 填充四周
      dst_img = np.zeros(shape=(size, size), dtype=np.uint8)
      dst_img[(size-h)//2:(size-h)//2+h, (size-w)//2:(size-w)//2+w] = src_img

   return dst_img


def deskew(src_img):
   """Deskew the image using its second order moments"""
   m = cv2.moments(src_img)

   if abs(m['mu02']) < 1e-2:
      return src_img.copy()

   skew = m['mu11']/m['mu02']
   M = np.float32([[1, -skew, 0.5*SZ*skew], [0, 1, 0]])
   dst_img = cv2.warpAffine(src_img, M, (SZ, SZ), cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)

   return dst_img


def hog(image):
   """Get histogram of oriented gradient of given image"""
   gx = cv2.Sobel(image, cv2.CV_32F, 1, 0)
   gy = cv2.Sobel(image, cv2.CV_32F, 0, 1)
   mag, ang = cv2.cartToPolar(gx, gy)
   bins = np.int32(bin_n*ang/(2*np.pi)) # quantizing binvalues in (0, ..., 16)
   bin_cells = bins[:14,:14], bins[14:,:14], bins[:14,14:], bins[14:,14:]
   mag_cells = mag[:14,:14], mag[14:,:14], mag[:14,14:], mag[14:,14:]
   hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
   hist = np.hstack(hists) # hist is a 64bit vector
   return hist


def pre_process(src_img):
   """图像预处理"""
   img_resize = resize(src_img, SZ)
   img_invert = cv2.bitwise_not(img_resize) # 颜色翻转
   img_deskew = deskew(img_invert)
   hist = hog(img_deskew)
   return hist


def exprRecognize(src_img, filter_size):
   """手写算式识别"""
   # 灰度图
   gray = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY)

   # 二值化
   _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
   binary_inv = cv2.bitwise_not(binary)

   # 中值滤波
   filter_size = int(filter_size[0][0]) if filter_size else 3
   binary_f = cv2.medianBlur(binary_inv, filter_size)

   # 查找字符区域
   contours, _ = cv2.findContours(binary_f, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

   # 遍历所有区域,寻找最大宽度
   w_max = 0
   for cnt in contours:
      _, _, w, _ = cv2.boundingRect(cnt)
      if w > w_max:
         w_max = w

   # 遍历所有区域,拼接x坐标接近的区域
   char_dict = {}
   for cnt in contours:
      x, y, w, h = cv2.boundingRect(cnt)
      x_mid = x + w//2 # 计算中点位置

      if not char_dict.keys() or all(np.abs(z - x_mid) > w_max/1.5 for z in char_dict.keys()):
         char_dict[x_mid] = cnt
      else:
         for z in char_dict.keys():
            if np.abs(z - x_mid) <= w_max/1.5:
               char_dict[z] = np.concatenate((char_dict[z], cnt), axis=0) # 拼接两个区域

   # 按照中点坐标,对字符进行排序
   char_dict = dict(sorted(char_dict.items(), key=lambda item: item[0]))

   # 遍历所有区域,提取字符
   dst_img = []
   for _, cnt in char_dict.items():
      x, y, w, h = cv2.boundingRect(cnt)
      roi = binary[y:y+h, x:x+w]
      dst_img.append(roi)

   expr = ''
   for char in dst_img:
      hist = pre_process(char)
      hist = np.array(hist, dtype=np.float32)
      result = model.predict(hist.reshape(-1, 4*bin_n))[1]
      expr += chars[int(result[0])]

   return dst_img, expr, eval(expr.replace('=', ''))

if __name__ == "__main__":
   demo = gr.Interface(
      fn=exprRecognize,
      inputs=[
         gr.Image(label="input image"), 
         gr.Radio(['3x3', '5x5', '7x7'], value='3x3')
      ],
      outputs=[
         gr.Gallery(label="charset", columns=[3], object_fit="contain", height="auto"),
         gr.Text(label="expression"),
         gr.Text(label="result")
      ],
      live=True
   )

   demo.launch()

  • 17
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

洋洋Young

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值