1. Dataset split
import random

# Read every annotation line from the label file
with open("Label.txt", "r", encoding="utf-8") as f:
    li_all = [line.strip("\n") for line in f]

count = len(li_all)
tra = int(0.9 * count)  # 90% of the samples go to the training set
print("Training samples:", tra)
print("Validation samples:", count - tra)

# Randomly pick tra indices for the training set; the rest go to validation
train = set(random.sample(range(count), tra))

# Write each line to train.txt or val.txt; the with block closes both files
with open("train.txt", "w", encoding="utf-8") as train_txt, \
        open("val.txt", "w", encoding="utf-8") as val_txt:
    for i in range(count):
        if i in train:
            train_txt.write(li_all[i] + "\n")
        else:
            val_txt.write(li_all[i] + "\n")
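The same 90/10 split can also be wrapped in a small function with a fixed random seed, so that repeated runs produce the same train/validation files. This is only a minimal sketch and not part of the original script: the function name `split_lines` and the `train_ratio` and `seed` parameters are assumptions introduced for illustration.

```python
import random


def split_lines(lines, train_ratio=0.9, seed=0):
    """Split lines into (train, val) lists, keeping the original order.

    split_lines, train_ratio and seed are hypothetical names for this sketch.
    """
    count = len(lines)
    n_train = int(train_ratio * count)
    # A local generator keeps the split reproducible without touching
    # the global random state.
    rng = random.Random(seed)
    train_idx = set(rng.sample(range(count), n_train))
    train = [line for i, line in enumerate(lines) if i in train_idx]
    val = [line for i, line in enumerate(lines) if i not in train_idx]
    return train, val


if __name__ == "__main__":
    demo = [f"img_{i}.jpg" for i in range(10)]
    train, val = split_lines(demo)
    print("train:", len(train), "val:", len(val))
```

The returned lists could then be written to train.txt and val.txt in the same way as in the script above. Changing `seed` gives a different but still repeatable split.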
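2. Modify the configuration file
Configuration file path: ./PaddleOCR/configs/det/det_mv3_db.yml
Note: in this example, the training images and the annotation label files are both stored under ./data.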
Global:
  use_gpu: True # default is True
  use_xpu: false
  use_mlu: false
  epoch_num: 500 # ====================== change ======================
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/ # ====================== change ======================
  save_epoch_step: 100 # ====================== change ======================
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ch_ppocr_mobile_v2.0_det_train/best_accuracy # ====================== change ======================
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt
Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50
Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0
PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
Metric:
  name: DetMetric
  main_indi