Training and Testing PaddleOCR on Your Own Dataset

1. Splitting the dataset

import random

# Read every annotation line from the labeling output.
with open("Label.txt", "r", encoding="utf-8") as f:
    li_all = [line.strip("\n") for line in f]

count = len(li_all)
tra = int(0.9 * count)   # 90% of the lines go to the training set
print("Training samples:", tra)
print("Validation samples:", count - tra)

# Randomly pick `tra` line indices for the training set.
train = set(random.sample(range(count), tra))

# Write each line to train.txt or val.txt; both files are closed on exit.
with open("train.txt", "w", encoding="utf-8") as train_txt, \
     open("val.txt", "w", encoding="utf-8") as val_txt:
    for i, line in enumerate(li_all):
        if i in train:
            train_txt.write(line + "\n")
        else:
            val_txt.write(line + "\n")
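
The split above changes on every run. As a minimal sketch, the same split can be made reproducible by seeding a local random generator. The helper name `split_lines` and the `seed` parameter are illustrative and not part of the original script.

```python
import random

def split_lines(lines, train_ratio=0.9, seed=0):
    """Return (train, val) lists; hypothetical helper, not from the original script."""
    rng = random.Random(seed)   # local generator, global random state is untouched
    n_train = int(train_ratio * len(lines))
    train_idx = set(rng.sample(range(len(lines)), n_train))
    train = [line for i, line in enumerate(lines) if i in train_idx]
    val = [line for i, line in enumerate(lines) if i not in train_idx]
    return train, val

if __name__ == "__main__":
    demo = [f"img_{i}.jpg" for i in range(10)]
    train, val = split_lines(demo)
    print(len(train), len(val))
```

Running it twice with the same `seed` yields the same train/val membership, which makes later experiments comparable.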

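2. Modifying the configuration file

Configuration file: ./PaddleOCR/configs/det/det_mv3_db.yml
Note: here the training images and the annotation label files are all under the ./data directory.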

Global:
  use_gpu: True   # default is True
  use_xpu: false
  use_mlu: false
  epoch_num: 500  # changed
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/    # changed
  save_epoch_step: 100          # changed
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ch_ppocr_mobile_v2.0_det_train/best_accuracy    # changed
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./     # changed
    label_file_list:
      - ./data/train.txt    # changed
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [640, 640]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 1    # was 16; changed
    num_workers: 1    # changed
    use_shared_memory: True

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./     # changed
    label_file_list:
      - ./data/val.txt    # changed
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1; changed
    num_workers: 8    # changed
    use_shared_memory: True
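
With `SimpleDataSet`, each image path in a label file is joined onto `data_dir`, so a mismatch between the two typically shows up only once training starts. The following is a minimal sketch for checking the paths up front. It assumes tab-separated label lines (PaddleOCR's default delimiter), and `find_missing_images` is a hypothetical helper name, not a PaddleOCR API.

```python
import os

def find_missing_images(label_file, data_dir, delimiter="\t"):
    """Return image paths listed in label_file that do not exist under data_dir."""
    missing = []
    with open(label_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue   # skip blank lines
            rel_path = line.split(delimiter, 1)[0]
            if not os.path.exists(os.path.join(data_dir, rel_path)):
                missing.append(rel_path)
    return missing

if __name__ == "__main__":
    # Paths taken from the Train/Eval sections of the config.
    for label_file in ("./data/train.txt", "./data/val.txt"):
        if os.path.exists(label_file):
            print(label_file, find_missing_images(label_file, "./"))
```

An empty list for both files suggests that `data_dir` and the paths inside the label files agree.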


3. Training on your own dataset

python tools/train.py -c configs/det/det_mv3_db.yml
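
Config values can also be overridden on the command line with `-o key=value` instead of editing the YAML file. As a rough sketch, a small wrapper can assemble such a command. `build_train_cmd` is an illustrative helper, and the override values are only examples taken from the config above.

```python
import subprocess
import sys

def build_train_cmd(config, overrides=None):
    """Build the tools/train.py command line; hypothetical helper for illustration."""
    cmd = [sys.executable, "tools/train.py", "-c", config]
    if overrides:
        cmd.append("-o")
        cmd.extend(f"{key}={value}" for key, value in overrides.items())
    return cmd

if __name__ == "__main__":
    cmd = build_train_cmd(
        "configs/det/det_mv3_db.yml",
        {"Global.epoch_num": 500, "Global.save_model_dir": "./output/db_mv3/"},
    )
    print(" ".join(cmd))
    if "--run" in sys.argv:   # only launch training when explicitly requested
        subprocess.run(cmd, check=True)
```

The script is meant to be run from the PaddleOCR repository root so that `tools/train.py` resolves.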

4. Resuming training from a checkpoint

python tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3_0606/latest_accuracy
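
`Global.checkpoints` takes a path prefix rather than a file name, so the directory and prefix must match what was actually saved under your `save_model_dir`. The sketch below checks a prefix before resuming. It assumes the `.pdparams` (weights) and `.pdopt` (optimizer state) files that Paddle checkpoints normally consist of, and `checkpoint_files` is an illustrative helper.

```python
import os

def checkpoint_files(prefix, suffixes=(".pdparams", ".pdopt")):
    """Map each expected checkpoint file to whether it exists on disk."""
    return {prefix + s: os.path.exists(prefix + s) for s in suffixes}

if __name__ == "__main__":
    # Prefix copied from the resume command above.
    status = checkpoint_files("./output/db_mv3_0606/latest_accuracy")
    for path, found in status.items():
        print(("found   " if found else "missing ") + path)
```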

IV. Training text recognition on your own dataset

1. Building the dataset

Process the annotated images as follows:

import json
import os
from PIL import Image

def mkdir(path):
    # Create the directory if it does not already exist.
    os.makedirs(path, exist_ok=True)

img_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/data/"  # source image directory
img_txt_path = "train_data/rec_ch/"    # image path prefix written into the new label files
img_save_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/rec_ch/"   # output directory for the cropped images and label files, i.e. the training set
mkdir(img_save_path)
li = ["train", "test"]   # annotation files to process (train.txt, test.txt)
for txt in li:
    ocr_li = []   # label texts already written for this split
    img_save = img_save_path + txt + "/"   # where the cropped images are saved
    mkdir(img_save)
    with open(f"E:/PycharmProjects/meter_detection/PaddleOCR/train_data/{txt}.txt", "r", encoding="utf-8") as f:
        data = f.readlines()
    # New label file for the recognition dataset; closed automatically on exit.
    with open(f"{img_save_path}rec_gt_{txt}.txt", "w", encoding="utf-8") as new_txt:
        for da in data:
            da_new = da.strip("\n")
            # Each line is "image_path<TAB>JSON list of boxes"; the JSON itself may contain spaces.
            img_name, img_info = da_new.split("\t", 1)
            img_name = img_name.split("/")[-1]
            img = Image.open(img_path + img_name)
            img_info = json.loads(img_info)  # parse the JSON string into a list (safer than eval)
            i = 1
            for di in img_info:
                new_name = os.path.splitext(img_name)[0] + "_" + str(i) + ".jpg"
                img_new_path = img_txt_path + txt + "/" + new_name  # image path as written in the label file
                label = di["transcription"]
                points = di["points"]
                # x and y coordinates of the four corner points
                x_coordinates = [point[0] for point in points]
                y_coordinates = [point[1] for point in points]
                # Axis-aligned bounding box of the crop region
                left = min(x_coordinates)
                upper = min(y_coordinates)
                right = max(x_coordinates)
                lower = max(y_coordinates)
                # Only the first box for each distinct label text is kept.
                if label not in ocr_li:
                    ocr_li.append(label)
                    # Tab-separated, matching the dataset's default delimiter.
                    new_txt.write(img_new_path + "\t" + label + "\n")
                    new_img = img.crop((left, upper, right, lower))   # top-left and bottom-right corners
                    new_img.save(img_save + new_name)
                    i += 1
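
The cropping step reduces each four-point polygon to its axis-aligned bounding box. Here is a minimal, PIL-free sketch of that step on a single annotation line. The sample line is made up for illustration and assumes the tab-separated `image_path<TAB>JSON` layout.

```python
import json

def bbox_from_points(points):
    """Return (left, upper, right, lower) enclosing a list of [x, y] points."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)

def parse_label_line(line):
    """Split one annotation line into (image_name, list of box dicts)."""
    path, info = line.rstrip("\n").split("\t", 1)
    return path.split("/")[-1], json.loads(info)

if __name__ == "__main__":
    # Made-up sample: one box with four corner points.
    sample = ('data/meter_01.jpg\t[{"transcription": "123.4", '
              '"points": [[10, 20], [110, 22], [108, 60], [9, 58]]}]')
    name, boxes = parse_label_line(sample)
    for box in boxes:
        print(name, box["transcription"], bbox_from_points(box["points"]))
```

For a strongly rotated box, the bounding box includes extra background around the text, which may matter for recognition quality.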

Where the training images and txt files are stored:
[image]
Example of the txt file format:
[image]
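
A recognition label file is expected to hold one `image_path<TAB>text` pair per line (tab is the default delimiter for `SimpleDataSet`). The sketch below checks lines for that shape and flags labels longer than `max_text_length`, which is 25 in the recognition config. The helper name `check_rec_lines` and the sample lines are illustrative.

```python
def check_rec_lines(lines, max_text_length=25):
    """Return (line_number, reason) for every line that looks malformed."""
    problems = []
    for number, line in enumerate(lines, start=1):
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 2 or not parts[0] or not parts[1]:
            problems.append((number, "expected 'image_path<TAB>text'"))
        elif len(parts[1]) > max_text_length:
            problems.append((number, "label longer than max_text_length"))
    return problems

if __name__ == "__main__":
    sample = [
        "train_data/rec_ch/train/meter_01_1.jpg\t123.4",
        "train_data/rec_ch/train/meter_01_2.jpg 567",   # space instead of tab
    ]
    print(check_rec_lines(sample))
```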

2. Modifying the configuration file

Configuration file path: configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

Global:
  debug: false
  use_gpu: true
  epoch_num: 500    # was 800; changed
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 100   # was 3; changed
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model: pretrain_models/rec_train/ch_PP-OCRv2_rec_slim/ch_PP-OCRv3_rec_train/best_accuracy   # changed
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [700]
    values: [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True
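
With `epoch_num: 500` and `decay_epochs: [700]` as configured above, the learning-rate drop to 0.00005 is never reached during training. A simplified, epoch-level sketch of the `Piecewise` schedule illustrates this. It assumes a linear warmup from 0 over `warmup_epoch` epochs, whereas the real scheduler updates per iteration, and `lr_at_epoch` is an illustrative helper rather than a PaddleOCR function.

```python
def lr_at_epoch(epoch, decay_epochs=(700,), values=(0.0005, 0.00005), warmup_epoch=5):
    """Approximate learning rate at the start of a given epoch (0-based)."""
    if epoch < warmup_epoch:
        # Linear warmup from 0 towards the first value (assumed behaviour).
        return values[0] * epoch / warmup_epoch
    # Piecewise-constant part: count how many boundaries have been passed.
    stage = sum(epoch >= boundary for boundary in decay_epochs)
    return values[stage]

if __name__ == "__main__":
    for epoch in (0, 5, 499, 700):
        print(epoch, lr_at_epoch(epoch))
```

If the decay is wanted within a 500-epoch run, `decay_epochs` would need a value below 500; which value works best depends on the dataset.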
