1. Dataset split
import random

# Read all annotation lines (Label.txt, e.g. as written by PPOCRLabel)
with open("Label.txt", "r", encoding="utf-8") as f:
    li_all = [line.strip("\n") for line in f]

count = len(li_all)
tra = int(0.9 * count)  # 90% of the samples go to the training set
print("Training samples:", tra)
print("Validation samples:", count - tra)

# Randomly choose tra indices for training; a set makes the membership test fast
train = set(random.sample(range(count), tra))

# The config below reads ./data/train.txt and ./data/val.txt, so move the files there
with open("train.txt", "w", encoding="utf-8") as train_txt, \
        open("val.txt", "w", encoding="utf-8") as val_txt:
    for i in range(count):
        if i in train:
            train_txt.write(li_all[i] + "\n")
        else:
            val_txt.write(li_all[i] + "\n")
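If the split needs to be reproducible (for example, to compare two training runs on the same validation set), one option is to seed a private random generator and drop blank lines before sampling. `split_label_lines` is a hypothetical helper, not part of PaddleOCR, and the 0.9 default ratio simply mirrors the script above.

```python
import random


def split_label_lines(lines, train_ratio=0.9, seed=0):
    """Hypothetical helper: split annotation lines into (train, val) lists.

    Blank lines are dropped. A seeded random.Random instance makes the split
    identical on every run for the same input and seed.
    """
    samples = [line.rstrip("\n") for line in lines if line.strip()]
    rng = random.Random(seed)
    n_train = int(train_ratio * len(samples))
    train_idx = set(rng.sample(range(len(samples)), n_train))
    train = [s for i, s in enumerate(samples) if i in train_idx]
    val = [s for i, s in enumerate(samples) if i not in train_idx]
    return train, val


if __name__ == "__main__":
    demo = [f"data/img_{i}.jpg\t[]\n" for i in range(10)] + ["\n"]
    train, val = split_label_lines(demo)
    print(len(train), len(val))  # 9 1
```

Using a dedicated `random.Random(seed)` instead of the global `random` module keeps the split stable even if other code consumes random numbers first.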
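2. Modify the configuration file
Config file: ./PaddleOCR/configs/det/det_mv3_db.yml
Note: here both the training images and the label files (train.txt, val.txt) are stored under the ./data directory. Lines marked "modified" are the ones changed from the default config.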
Global:
  use_gpu: True # default is True
  use_xpu: false
  use_mlu: false
  epoch_num: 500 # ====================== modified ======================
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/ # ====================== modified ======================
  save_epoch_step: 100 # ====================== modified ======================
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ch_ppocr_mobile_v2.0_det_train/best_accuracy # ====================== modified ======================
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt
Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./ # ====================== modified ======================
    label_file_list:
      - ./data/train.txt # ====================== modified ======================
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [640, 640]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 1 # default 16 ====================== modified ======================
    num_workers: 1 # ====================== modified ======================
    use_shared_memory: True
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./ # ====================== modified ======================
    label_file_list:
      - ./data/val.txt # ====================== modified ======================
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1 ====================== modified ======================
    num_workers: 8 # ====================== modified ======================
    use_shared_memory: True
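A common failure at this step is a label file whose image paths do not resolve against `data_dir`. The sketch below assumes that PaddleOCR's `SimpleDataSet` takes the first tab-separated field of each label line and joins it onto `data_dir` with `os.path.join`. That matches my reading of recent releases but is worth checking against your version. `find_missing_images` is a hypothetical helper.

```python
import os


def find_missing_images(label_file, data_dir, delimiter="\t"):
    """Hypothetical check: list label-file image paths that do not exist on disk.

    Assumes each line is '<relative image path><delimiter><annotation>' and
    that the image is looked up at os.path.join(data_dir, relative path).
    """
    missing = []
    with open(label_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # ignore blank lines
            rel_path = line.split(delimiter, 1)[0]
            if not os.path.exists(os.path.join(data_dir, rel_path)):
                missing.append(rel_path)
    return missing


if __name__ == "__main__":
    # data_dir is ./ in both the Train and Eval sections above
    for name in ("./data/train.txt", "./data/val.txt"):
        if os.path.exists(name):
            print(name, "missing images:", find_missing_images(name, "./"))
```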
3. Train on your own dataset
python tools/train.py -c configs/det/det_mv3_db.yml
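With `batch_size_per_card: 1`, the `eval_batch_step: [0, 2000]` setting is counted in iterations, not epochs. A quick calculation shows how often evaluation will run. This is a rough estimate under stated assumptions, not PaddleOCR's exact scheduling code:

- one iteration per batch;
- `drop_last: False`, so a final partial batch still counts;
- an evaluation whenever the step past `start` is a multiple of the interval.

The 180-image figure is purely hypothetical.

```python
import math


def training_schedule(num_images, batch_size_per_card, num_cards, epoch_num,
                      eval_start=0, eval_interval=2000):
    """Estimate (iterations per epoch, total iterations, number of evaluations)."""
    # drop_last: False -> the last partial batch is still one iteration
    iters_per_epoch = math.ceil(num_images / (batch_size_per_card * num_cards))
    total_iters = iters_per_epoch * epoch_num
    # Evaluations at eval_start + k * eval_interval (k >= 1), up to total_iters
    num_evals = max(0, (total_iters - eval_start) // eval_interval)
    return iters_per_epoch, total_iters, num_evals


if __name__ == "__main__":
    # Hypothetical: 180 training images on one GPU, epoch_num 500 as configured above
    print(training_schedule(180, 1, 1, 500))  # (180, 90000, 45)
```

If the estimated number of evaluations is very small (or zero for a tiny dataset), lowering the second value of `eval_batch_step` is the knob to turn.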
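4. Resume training from a checkpoint
Point Global.checkpoints at a checkpoint prefix inside save_model_dir, without the .pdparams extension:
python tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3/latest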
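To avoid typing the checkpoint prefix by hand, a small hypothetical helper can look inside `save_model_dir` for a checkpoint to resume from. It assumes PaddleOCR writes `latest.pdparams`, plus `iter_epoch_N.pdparams` every `save_epoch_step` epochs, into that directory. That naming matches recent releases as far as I know, but confirm it against the files your version actually produces.

```python
import os
import re


def resume_prefix(save_model_dir):
    """Hypothetical helper: pick the prefix to pass as Global.checkpoints.

    Prefers '<dir>/latest'; otherwise falls back to the highest 'iter_epoch_N'.
    Returns None if the directory or any matching .pdparams file is missing.
    """
    if not os.path.isdir(save_model_dir):
        return None
    if os.path.exists(os.path.join(save_model_dir, "latest.pdparams")):
        return os.path.join(save_model_dir, "latest")
    epochs = []
    for name in os.listdir(save_model_dir):
        match = re.fullmatch(r"iter_epoch_(\d+)\.pdparams", name)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        return None
    return os.path.join(save_model_dir, f"iter_epoch_{max(epochs)}")


if __name__ == "__main__":
    prefix = resume_prefix("./output/db_mv3/")
    if prefix:
        print("python tools/train.py -c configs/det/det_mv3_db.yml "
              f"-o Global.checkpoints={prefix}")
```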
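IV. Training text recognition on your own dataset
1. Build the dataset
Process the annotated images as follows. Each labeled text region is cropped into its own image, and a new recognition label file is written: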
import json
import os
from PIL import Image


def mkdir(path):
    os.makedirs(path, exist_ok=True)


img_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/data/"  # directory of the original images
img_txt_path = "train_data/rec_ch/"  # image path prefix written into the new label files
img_save_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/rec_ch/"  # where cropped images and new label files go (the training set directory)
mkdir(img_save_path)
li = ["train", "test"]  # label files to process
for txt in li:
    ocr_li = []  # unique transcriptions seen in this split
    img_save = img_save_path + txt + "/"  # save path for the cropped images
    mkdir(img_save)
    with open(f"E:/PycharmProjects/meter_detection/PaddleOCR/train_data/{txt}.txt", "r", encoding="utf-8") as f:
        data = f.readlines()
    # New recognition label file
    with open(f"{img_save_path}rec_gt_{txt}.txt", "w", encoding="utf-8") as new_txt:
        for da in data:
            da_new = da.strip("\n")
            if not da_new:
                continue
            # Each line is "<image path>\t<JSON list>"; the JSON part itself contains spaces
            img_name, img_info = da_new.split("\t", 1)
            img_name = img_name.split("/")[-1]
            img = Image.open(img_path + img_name).convert("RGB")  # RGB so the crops can be saved as JPEG
            img_info = json.loads(img_info)  # JSON string -> list of dicts (eval fails on false/true)
            i = 1
            for di in img_info:
                new_name = os.path.splitext(img_name)[0] + "_" + str(i) + ".jpg"
                img_new_path = img_txt_path + txt + "/" + new_name  # image path + name as written to the txt file
                label = di["transcription"]
                points = di["points"]
                # x and y coordinates of the four corner points
                x_coordinates = [point[0] for point in points]
                y_coordinates = [point[1] for point in points]
                # Axis-aligned crop box around the four points
                left = min(x_coordinates)
                upper = min(y_coordinates)
                right = max(x_coordinates)
                lower = max(y_coordinates)
                if label not in ocr_li:
                    ocr_li.append(label)
                # SimpleDataSet expects a tab between image path and label by default
                new_txt.write(img_new_path + "\t" + label + "\n")
                new_img = img.crop((left, upper, right, lower))  # (left, upper) and (right, lower) corners
                new_img.save(img_save + new_name)
                i += 1
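Each line of the detection label files (train.txt / test.txt, copied from Label.txt) holds the image path, a tab, and then a JSON list of boxes with `transcription`, `points` and a `difficult` flag. The JSON part contains spaces and JSON literals such as `false`. The robust approach is therefore to split on the first tab and call `json.loads`, because splitting on spaces or using `eval` breaks on such lines.

A minimal parser is sketched below. `parse_label_line` and the sample line are hypothetical. It also computes the same axis-aligned crop box used above.

```python
import json
import os


def parse_label_line(line):
    """Hypothetical helper: parse one Label.txt-style line into (image name, regions).

    Each region is (transcription, (left, upper, right, lower)), where the box
    is the axis-aligned rectangle around the four annotated points, i.e. the
    tuple that PIL's Image.crop expects.
    """
    img_path, annotations = line.rstrip("\n").split("\t", 1)
    regions = []
    for item in json.loads(annotations):
        xs = [p[0] for p in item["points"]]
        ys = [p[1] for p in item["points"]]
        regions.append((item["transcription"], (min(xs), min(ys), max(xs), max(ys))))
    return os.path.basename(img_path), regions


if __name__ == "__main__":
    sample = ('data/meter_01.jpg\t[{"transcription": "12.5", '
              '"points": [[10, 20], [60, 22], [59, 40], [9, 38]], "difficult": false}]')
    print(parse_label_line(sample))
    # ('meter_01.jpg', [('12.5', (9, 20, 60, 40))])
```

For strongly tilted text, an axis-aligned box also captures background around the characters; a perspective-rectified crop is an alternative worth considering there.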
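Storage location of the training images and txt files: the cropped images are saved to rec_ch/train/ and rec_ch/test/ under img_save_path, and the label files are rec_ch/rec_gt_train.txt and rec_ch/rec_gt_test.txt.
Example of the txt file format: each line holds the relative image path written by the script (img_txt_path + split name + "/" + cropped image name), followed by the transcription of that image.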
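Before training, it can pay to sanity-check the recognition label files. As far as I know, `SimpleDataSet` splits each line on a tab by default. I also understand that PaddleOCR's label encoder skips samples longer than `max_text_length` (25 in the config) and drops characters missing from `character_dict_path` (one character per line, with a space added when `use_space_char: true`).

The helpers below are hypothetical and only flag such lines; they do not reproduce PaddleOCR's own code.

```python
import os


def load_charset(dict_path, use_space_char=True):
    """Read a one-character-per-line dictionary file into a set.

    With use_space_char=True the space character is added, mirroring the
    use_space_char option in the config.
    """
    with open(dict_path, "r", encoding="utf-8") as f:
        chars = {line.rstrip("\r\n") for line in f if line.rstrip("\r\n")}
    if use_space_char:
        chars.add(" ")
    return chars


def check_rec_labels(lines, charset, max_text_length=25):
    """Hypothetical check: return (line_number, problem) for suspicious label lines."""
    problems = []
    for no, line in enumerate(lines, start=1):
        line = line.rstrip("\r\n")
        if "\t" not in line:
            problems.append((no, "no tab delimiter"))
            continue
        text = line.split("\t", 1)[1]
        if not text:
            problems.append((no, "empty transcription"))
        elif len(text) > max_text_length:
            problems.append((no, f"longer than {max_text_length} characters"))
        unknown = sorted(set(text) - set(charset))
        if unknown:
            problems.append((no, "characters not in dictionary: " + "".join(unknown)))
    return problems


if __name__ == "__main__":
    dict_path = "ppocr/utils/ppocr_keys_v1.txt"
    label_path = "train_data/rec_ch/rec_gt_train.txt"
    if os.path.exists(dict_path) and os.path.exists(label_path):
        charset = load_charset(dict_path)
        with open(label_path, "r", encoding="utf-8") as f:
            for no, problem in check_rec_labels(f, charset):
                print(no, problem)
```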
2. Modify the configuration file
Config file: configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
Global:
  debug: false
  use_gpu: true
  epoch_num: 500 # default 800 ====================== modified ======================
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 100 # default 3 ====================== modified ======================
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model: pretrain_models/rec_train/ch_PP-OCRv2_rec_slim/ch_PP-OCRv3_rec_train/best_accuracy # ====================== modified ======================
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [700] # with epoch_num: 500 this decay boundary is never reached
    values: [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05
Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True
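The `Piecewise` schedule in the Optimizer section keeps `values[0]` until the first boundary in `decay_epochs`, then switches to `values[1]`, after `warmup_epoch` epochs of warmup.

The per-epoch sketch below is a simplification. It assumes a linear warmup starting from 0, and PaddleOCR actually steps the schedule per iteration. It is meant only to show that with `epoch_num: 500` the decay at epoch 700 never takes effect, so training runs at 0.0005 throughout.

```python
def piecewise_lr(epoch, decay_epochs=(700,), values=(0.0005, 0.00005), warmup_epoch=5):
    """Simplified per-epoch view of a Piecewise schedule with linear warmup.

    epoch may be fractional. During warmup the rate climbs linearly from 0 to
    values[0]; afterwards it is values[k], where k is the number of
    decay_epochs boundaries already reached.
    """
    if epoch < warmup_epoch:
        return values[0] * epoch / warmup_epoch
    passed = sum(1 for boundary in decay_epochs if epoch >= boundary)
    return values[passed]


if __name__ == "__main__":
    epoch_num = 500  # value set in Global above
    final_lr = piecewise_lr(epoch_num - 1)
    print("learning rate in the last epoch:", final_lr)
    print("decay reached:", final_lr < 0.0005)
```

If a decay is wanted within 500 epochs, moving the boundary below `epoch_num` (for example, scaling it by the same 500/800 ratio as the epoch count) is one option.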