基于 EasyOCR 微调 CRAFT 文本检测模型

点击下方卡片,关注“小白玩转Python”公众号

EasyOCR 中的 CRAFT 模型(用于文本检测的字符区域感知)用于检测图像内文本的边界框。然后,这些边界框被发送到 EasyOCR 的文本识别器模块,以读取每个图像中的文本。CRAFT 模块和文本识别器共同构成了 EasyOCR 的管道。在上一篇文章中,我向我们展示了如何微调文本识别器模块,而本文将重点介绍如何微调 EasyOCR 的 CRAFT 模块。一起,微调 EasyOCR 模块的两个模块可以帮助构建强大的 OCR 引擎,我们可以将其用于所需的用例。

背景

本文教我们如何微调组成 EasyOCR 的两个模块之一,即文本检测模块。文本检测模块检测图像中的边界框,这是良好 OCR 引擎的重要组成部分。如果我们希望 OCR 在我们的特定用例中具有最佳性能,则微调 CRAFT 模型至关重要。

创建数据集

要运行微调,首先需要一个数据集。为了方便遵循本教程,我在这个 Google Drive 文件夹中创建了一个示例数据集,其中包含一些我们可以使用的示例图像。如果我们使用它,我们可以直接跳到下一部分。

如果我们想使用自己的数据集,可以按照我这里的教程创建一个合成收据数据集,该数据集可用于微调 EasyOCR 的文本检测模块(CRAFT)和文本识别模型。然后,我们可以使用以下函数将我上面教程中的格式转换为 CRAFT 模型所需的格式:

import json
import os
import torch


def load_dataset(train_split = 0.8, folder_path = "CRAFT_data/" ):
    #first train
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    else:
        print("FOLDER ALREADY EXISTS, WILL NOT OVERWRITE SO RETURNING. REMOVE FOLDER IF YOU WANT TO RECREATE")
        return


    #TODO open labels.json file
    path_to_labels_file = "./ReceiptData/Receipt2Synth/labels.json"
    with open(path_to_labels_file) as f:
        labels = json.load(f)


    full_folder_path = folder_path + "ch4_training_localization_transcription_gt"
    if not os.path.exists(full_folder_path):
        os.makedirs(full_folder_path)
    split_idx = int(len(labels) * train_split)


    for idx, (image_info) in enumerate(labels):
        if (idx == split_idx):
            full_folder_path = folder_path + "ch4_test_localization_transcription_gt"
            if not os.path.exists(full_folder_path):
                os.makedirs(full_folder_path)


        image_path, labels_and_bbs = image_info["image_path"], image_info["strings_and_bbs"]
        image_name = image_path.split('/')[-1].split('.')[0]
        # now take all labels and bounding boxes and convert to correct format
        with open(f'{full_folder_path}/gt_{image_name}.txt', 'w', encoding='utf-8') as f:
            for (lbl,x1,y1,w,h) in labels_and_bbs: # x1,y1 is top left corner (2 is top right, 3 is bot right, 4 is bot left)
                x2 = x1 + w
                y2 = y1
                x3 = x1 + w
                y3 = y1 + h
                x4 = x1
                y4 = y1 + h
                full_label_string = [x1,y1,x2,y2,x3,y3,x4,y4,lbl]
                full_label_string = f"{x1},{y1},{x2},{y2},{x3},{y3},{x4},{y4},{lbl}"
                #then save to txt
                f.write(full_label_string + '\n')
            f.close()


load_dataset()
  • <path to the labels.json made in the tutorial above>是我在创建合成收据数据集的教程中制作的 labels.json 文件的路径。

这将创建一个名为CRAFT_data的文件夹,其中包含子文件夹ch4_training_localization_transcription_gt和ch4_test_localization_transcription_gt。请注意,CRAFT_data文件夹名称是可选的,但其他两个文件夹名称是必需的,因为训练脚本需要这些文件夹名称。我们还可以使用 train_split 参数更改训练和测试样本之间的分割。使用 train_split=0.8,数据集被分成 80% 的训练和 20% 的测试。

要理解代码,我们必须了解 EasyOCR 所需的边界框的坐标格式。格式为 (x1,y1,x2,y2,x3,y3,x4,y4),其中边界框上的坐标如下图所示(1 是左上角,2 是右上角,3 是右下角,4 是左下角):

561ba0fbc901e665d027119f8ede68c0.jpeg

EasyOCR 微调如何预期边界坐标

如果我们按照我的教程制作数据集,我们将获得 (x1,y1,w,h) 格式,其中 (x1,y1) 是边界框的左上角,w 是边界框的宽度,h 是边界框的高度。然后,上面的代码将 (x1,y1,w,h) 格式转换为 (x1,y1,x2,y2,x3,y3,x4,y4) 并将其存储在 txt 文件中。然后,在CRAFT_data文件夹中,添加文件夹:

  1. ch4_训练图像

  2. ch4_测试图像

并将教程中的 jpg 图像添加到这两个文件夹中。请注意,我们只应添加具有相应基本事实的图像。CRAFT_data文件夹的文件夹结构应如下所示:

05abd03e5e80fa6af45c6a1f42ce0457.jpeg

CRAFT_data 文件夹的文件夹结构

请注意,我们可能还有更多图像,出于可视化目的,我仅删除了其中的大部分图像。

克隆 Git 存储库

然后,克隆EasyOCR GitHub 存储库,其中包含可用于进行训练的代码。我们可以使用以下命令进行克隆:

git clone https://github.com/JaidedAI/EasyOCR.git

现在我们需要将CRAFT_data文件夹放在路径中:EasyOCR/trainer/craft,其中EasyOCR是我们刚刚克隆的 GitHub 存储库。

下载预先训练的模型

要微调模型,首先需要一个预先训练好的模型来进行微调。可以从这个 Google Drive下载(这不是我的 Google Drive,但它链接在有关微调 CRAFT 模型的 README底部)。将其放在我们的存储库中的某个位置(我在EasyOCR/trainer/craft下创建了一个名为pretrained_model的新文件夹,并将其放在该文件夹中)。

配置 YAML 文件

yaml 文件提供了用于训练的配置。要配置 yaml 文件以在自定义数据集上进行训练,请选择EasyOCR/trainer/craft/config文件夹中的custom_data_train.yaml文件。然后我们必须对文件进行 3 处更改。请注意,我假设我们使用的文件夹名称与Craft_data 和我相同。

1. 将第 6 行的data_root_dir更改为如下行:

data_root_dir: "./CRAFT_data/"

2. 将train 下第 17 行的ckpt_path更改为上一节中下载的预训练模型的路径。对我来说,这行是:

# ...
train:
  # ...  
  ckpt_path: "./pretrained_model/CRAFT_clr_amp_14000.pth"
  # ...
# ...

3. 将第 92 行的test/custom_data下的test_data_dir更改为如下行:

# ...
test:
  trained_model : null
  custom_data:
    # ...
    test_data_dir: "./CRAFT_data/"
    # ...

我的完整 yaml 文件如下:

wandb_opt: False


results_dir: "./exp/"
vis_test_dir: "./vis_result/"


data_root_dir: "./CRAFT_data/" 
# data_root_dir: "./CRAFT_data/" #TODO change when running full training
score_gt_dir: None # "/data/ICDAR2015_official_supervision" #TODO use this?? 
mode: "weak_supervision" #could also be None
# mode: None #could also be None




train:
  backbone : vgg
  use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option
  synth_data_dir: "/data/SynthText/"
  synth_ratio: 5
  real_dataset: custom
  ckpt_path: "./pretrained_model/CRAFT_clr_amp_14000.pth"
  # ckpt_path: "./pretrained_model/craft_mlt_25k.pth" # the baseline model from EasyOCR
  # eval_interval: 1000
  eval_interval: 10
  batch_size: 1 #TODO changed from 5 -> 1
  st_iter: 0
  end_iter: 30 #25000
  lr: 0.0001
  lr_decay: 7500
  gamma: 0.2
  weight_decay: 0.00001
  num_workers: 0 # On single gpu, train.py execution only works when num worker = 0 / On multi-gpu, you can set num_worker > 0 to speed up
  amp: True
  loss: 2
  neg_rto: 0.3
  n_min_neg: 5000
  data:
    vis_opt: False
    pseudo_vis_opt: False
    output_size: 768
    do_not_care_label: ['###', '']
    mean: [0.485, 0.456, 0.406]
    variance: [0.229, 0.224, 0.225]
    enlarge_region : [0.5, 0.5] # x axis, y axis
    enlarge_affinity: [0.5, 0.5]
    gauss_init_size: 200
    gauss_sigma: 40
    watershed:
      version: "skimage"
      sure_fg_th: 0.75
      sure_bg_th: 0.05
    syn_sample: -1
    custom_sample: -1
    syn_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: False
      random_crop:
        version: "random_resize_crop_synth"
        option: True
      random_horizontal_flip:
        option: False
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True
    custom_aug:
      random_scale:
        range: [ 1.0, 1.5, 2.0 ]
        option: False
      random_rotate:
        max_angle: 20
        option: True
      random_crop:
        version: "random_resize_crop"
        scale: [0.03, 0.4]
        ratio: [0.75, 1.33]
        rnd_threshold: 1.0
        option: True
      random_horizontal_flip:
        option: True
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True


test:
  trained_model : null
  custom_data:
    test_set_size: 500 #NOTE not used in train.py
    # test_data_dir: "./CRAFT_data/"
    test_data_dir: "./CRAFT_data/" # TODO change when running full training
    text_threshold: 0.75
    low_text: 0.5
    link_threshold: 0.2
    canvas_size: 2240
    mag_ratio: 1.75
    poly: False
    cuda: True
    vis_opt: False

运行微调

现在我们终于可以运行微调了。要运行训练,请在终端中使用以下命令:

python train.py --yaml=<your yaml file name>
  • <your yaml file name>是我们的 yaml 文件的名称

因此,对我来说,命令是:

python train.py --yaml=custom_data_train

训练后,model.pth 文件将存储在exp/<config name>文件夹中(<config name>是 config.yaml 文件的名称)

使用微调模型

在上一节中微调了 .pth 模型后,我们现在想在 EasyOCR 模型中使用它。为此,首先,将我们的 .pth 文件移动到要从中运行 EasyOCR 的目录。然后,我们可以使用以下代码来使用新的微调 OCR 模型:

#code to load custom craft model (from )
import easyocr
from easyocr.detection import get_detector, get_textbox
import torch
import cv2


save_pth= torch.load('<your fine-tuned .pth model>')
model = save_pth["craft"]
torch.save(model , "<a new name for your fine-tuned .pth model>") 


reader = easyocr.Reader(
    lang_list=["en"],
    detector=False,
)
reader.get_detector, reader.get_textbox = get_detector, get_textbox
reader.detector = reader.initDetector("<a new name for your fine-tuned .pth model>")


img = cv2.imread("<path to an image>")
reader.readtext(img, detail=0)
  • <your fine-tuned.pth model>是上一节中制作的.pth模型文件的路径

  • <a new name for your fine-tuned .pth model>是我们的模型的新名称(可以是任何我们想要的名称,只要它是一个 .pth 文件)

  • <path to an image>是要执行 OCR 的图像的路径

关于代码的说明

如果我们只使用英文字符,本节可能更有意义,但仍然值得一读。我正在对挪威语文本进行微调,其中包含一些非英文字符。这导致了很多难以调试的问题。因此,我建议检查我们在CRAFT_data文件夹中使用的.txt文件是否有任何未知字符,这些字符将在我们的 VSCode txt 文件阅读器中显示为问号。如果我们注意到这些未知字符,并且我们自己写入了文件,可以通过使用 encoding=”utf-8” 进行写入来修复,如下所示:

# Old method which gives errors:
with open('file.txt', 'w') as f:
  f.write("Some text with random norwegian characters æ ø å")


# New method which does not give errors
with open('file2.txt', 'w', encoding='utf-8') as f:
  f.write("Some text with random norwegian characters æ ø å")

包含错误的文件将如下所示:

da0ad9907c26a427eb08bb6f7b343661.jpeg包含 3 个未知字符的文件的样子,我们可以在图片右侧看到带有问号的未知字符

如果使用正确的编码写入文件,则会得到以下结果:

570e6c9bfeb7574750d8cea8c407fe04.jpeg

正确编码字符后看起来是什么样子

结论

在本教程中,我们学习了如何微调 EasyOCR 的文本检测模型,该模型称为 CRAFT。文本检测模型接收图像并输出包含图像中文本的边界框,这是 EasyOCR 的重要组成部分。微调 CRAFT 模型可以大大改善 OCR 引擎的结果。

·  END  ·

HAPPY LIFE

d6c6d510768cfe0d99a1676278742656.png

本文仅供学习交流使用,如有侵权请联系作者删除

  • 7
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值