使用Keras进行OCR | 附源码

光学字符识别(或光学字符阅读器,又名 OCR)是一种在过去二十年中用于识别和数字化图像中出现的字母和数字字符的技术。在行业中,这项技术可以帮助我们避免人工手动输入数据。

本文中,我们将了解如何将深度学习应用于OCR技术,以及对手写字符进行分类所需的步骤:

· 准备用于训练 OCR 模型的 0-9 和 A-Z 字母数据集。

· 加载数据集。

· 在数据集上成功训练 Keras 和 TensorFlow 模型。

· 绘制训练结果并可视化验证数据。

· 预测某些图像中存在的文本。

准备数据集

我们使用以下两个数据集来训练我们的 Keras 和 TensorFlow 模型。

· 0–9: MNIST

· A-Z: Kaggle

MNIST数据集

该数据集由 NIST 的特殊数据库3和特殊数据库1构建而成,其中包含手写数字的二进制图像。

It is built into popular deep learning frameworks, including Keras, TensorFlow, PyTorch, etc.
The MNIST dataset will allow us to recognize the digits 0–9.
Each of these digits is contained in a 28 x 28 grayscale image.

Kaggle数据集

This dataset takes the capital letters A–Z from NIST Special Database 19.
Kaggle also rescales them from 28 x 28 grayscale pixels to the same format as our MNIST data.

加载数据集

由于我们有两个独立的数据集,首先我们必须加载两个数据集并将它们组合成一个数据集。

加载 Kaggle 字母数据集

def load_az_dataset(dataset_path):
    # initialize the list of data and labels
    data = []
    labels = []
    # loop over the rows of the A-Z handwritten digit dataset
    for row in open(dataset_path):
        # parse the label and image from the row
        row = row.split(",")
        label = int(row[0])
        image = np.array([int(x) for x in row[1:]], dtype="uint8")


        # images are represented as single channel (grayscale) images
        # that are 28x28=784 pixels -- we need to take this flattened
        # 784-d list of numbers and reshape them into a 28x28 matrix
        image = image.reshape((28, 28))
        # update the list of data and labels
        data.append(image)
        labels.append(label)


        # convert the data and labels to NumPy arrays
        data = np.array(data, dtype="float32")
        labels = np.array(labels, dtype="int")
        # return a 2-tuple of the A-Z data and labels
        return (data, labels)

加载 MNIST 数字数据集

def load_zero_nine_dataset():
    # load the MNIST dataset and stack the training data and testing
    # data together (we'll create our own training and testing splits
    # later in the project)
    ((trainData, trainLabels), (testData, testLabels)) = mnist.load_data()
    data = np.vstack([trainData, testData])
    labels = np.hstack([trainLabels, testLabels])
    # return a 2-tuple of the MNIST data and labels
    return (data, labels)

合并数据集

...
# load all datasets
(azData, azLabels) = load_az_dataset(args["az"])
(digitsData, digitsLabels) = load_zero_nine_dataset()




# the MNIST dataset occupies the labels 0-9, so let's add 10 to every A-Z label to ensure the A-Z characters are not incorrectly labeled as digits
azLabels += 10




# stack the A-Z data and labels with the MNIST digits data and labels
data = np.vstack([azData, digitsData])
labels = np.hstack([azLabels, digitsLabels])
...

在数据集上训练模型

本文使用 Keras、TensorFlow 和 ResNet 架构来训练模型。

model = ResNet.build(32, 32, 1, len(le.classes_), (3, 3, 3),
                     (64, 64, 128, 256), reg=0.0005)
...
H = model.fit(
    aug.flow(trainX, trainY, batch_size=BS), validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS, epochs=EPOCHS,
    class_weight=classWeight,
    verbose=1)
...

使用以下命令训练该模型需要大约 30-45 分钟。

python train_model.py --az dataset/a_z_handwritten_data.csv --model trained_ocr.model
[INFO] loading datasets...
[INFO] compiling model...
[INFO] training network...
Epoch 1/50
 34/437 [=>……………………….] — ETA: 7:40 — loss: 2.5050 — accuracy: 0.2989
...

可视化结果

我们将绘制一个可视化,以确保它正常工作。

预测

我们使用以下代码进行预测

python prediction.py — model trained_ocr.model — image images/hello_world.png
[INFO] H - 92.48%
[INFO] W - 54.50%
[INFO] E - 94.93%
[INFO] L - 97.58%
[INFO] 2 - 65.73%
[INFO] L - 96.56%
[INFO] R - 97.31%
[INFO] 0 - 37.92%
[INFO] L - 97.13%
[INFO] D - 97.83%

完整的源代码可以在这里看到:https://github.com/housecricket/how-to-train-OCR-with-Keras-and-TensorFlow

文件树结构如下所示

├── __init__.py
├── dataset
│   └── a_z_handwritten_data.csv
├── images
│   ├── hello_world.png
│   └── vietnamxinchao.png
├── models
│   ├── __init__.py
│   └── resnet.py
├── prediction.py
├── requirements.txt
├── train_model.py
├── trained_ocr.model
└── utils.py

总结

在本文中,我们使用Keras、TensorFlow和Python来训练OCR模型,是不是很简单~作为一个深度学习的入门算法,快来试试吧~

·  END  ·

HAPPY LIFE

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值