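Model training itself is simple; following the official tutorial is enough:
Welcome to the MMOCR Chinese documentation! (MMOCR 1.0.1 docs)
A word of warning: among the text detection models, only db and mask_rcnn currently run for me. I could not get drrg or textsnake to run at all, so if you hit the error below, do not waste time on them. From what I found, the cause is supposedly a problem in the annotation data, but I checked the whole dataset several times and also converted it with the official conversion code, and the error persisted, so I gave up. (If you manage to get them running, I would appreciate any tips.)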
assert point.shape[0]>=4
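To make the assertion concrete: it appears to check the number of vertices of a polygon. The following is a minimal sketch, assuming the polygon is stored as a flat `[x1, y1, x2, y2, ...]` list and is reshaped to an `(N, 2)` array before the check. `vertex_count` is a hypothetical helper written for illustration, not an MMOCR function.

```python
import numpy as np


def vertex_count(flat_polygon):
    """Return the number of (x, y) vertices in a flat [x1, y1, x2, y2, ...] list."""
    points = np.asarray(flat_polygon, dtype=np.float32).reshape(-1, 2)
    return points.shape[0]


quad = [0, 0, 10, 0, 10, 5, 0, 5]   # 4 vertices: would pass a ">= 4" check
triangle = [0, 0, 10, 0, 5, 5]      # 3 vertices: would fail a ">= 4" check

print(vertex_count(quad) >= 4)
print(vertex_count(triangle) >= 4)
```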
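As a beginner, though, I was stuck on the dataset step for a long time, so here it is in detail:
The root directory is a path such as mmocr-main/data/brand_data/. Under it are train/images and the JSON files. (The txt files in the screenshot are my PaddleOCR label files; those and the output folder are unrelated and can be ignored.)
Normally the ctw config file is the one to modify:
Modify the four config files listed under _base_; apart from the dataset one, the others basically need no changes: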
python tools/train.py configs/textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_ctw1500.py
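For reference, a dataset config pointing at a custom dataset could look roughly like the sketch below. It is modeled on the layout of the stock ctw1500 dataset config in MMOCR 1.0 as I remember it, so compare it with the file in your own checkout. The variable names `brand_textdet_*`, the data root, and the assumption that the JSON files sit directly under that root are all illustrative choices, not something taken from the original setup.

```python
# Hypothetical dataset config for a custom text detection dataset.
# The names and paths below are assumptions; adapt them to your own layout.
brand_textdet_data_root = 'data/brand_data/brand_images'

brand_textdet_train = dict(
    type='OCRDataset',
    data_root=brand_textdet_data_root,
    ann_file='textdet_train.json',  # assumed to sit directly under data_root
    filter_cfg=dict(filter_empty_gt=True, min_size=32),
    pipeline=None)

brand_textdet_test = dict(
    type='OCRDataset',
    data_root=brand_textdet_data_root,
    ann_file='textdet_test.json',   # assumed to sit directly under data_root
    test_mode=True,
    pipeline=None)
```

With this layout, an `img_path` such as `test/xxx.jpg` in the annotation file is resolved relative to `data_root`.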
While we are at it, here is a script that converts PaddleOCR labels to the MMOCR format:
import json
import os

import cv2
import numpy as np

# Switch between 'train' and 'test' to convert the corresponding split.
# mode = 'train'
mode = 'test'
root = '/home/lixinru/Vir_env/mmocr-main/data/brand_data/brand_images'
label_path = f'{root}/textdet_{mode}.txt'  # PaddleOCR label file for this split
image_dir = f'{root}/{mode}/'              # folder holding this split's images

mmocr_label = {"metainfo": {"dataset_type": "TextDetDataset", "task_name": "textdet", "category": [{"id": 0, "name": "text"}]}, "data_list": []}
data_list = []

with open(label_path, 'r', encoding='utf-8') as label_file:
    lines = label_file.readlines()

for line in lines:
    # Each PaddleOCR line is "<image name>\t<JSON list of annotations>".
    anno = line.strip().split('\t')
    images_path = anno[0]
    img_path = os.path.join(image_dir, images_path)
    print(img_path)
    paddle_label = json.loads(anno[1])

    per_img_anno = []
    for item in paddle_label:
        points = item['points']
        # Axis-aligned bounding rectangle as (x, y, w, h).
        # cv2.boundingRect expects int32 or float32 points.
        x, y, w, h = cv2.boundingRect(np.array(points, dtype=np.int32))
        x2, y2 = x + w - 1, y + h - 1
        if len(points) == 4:
            # Keep the original quadrilateral, flattened to [x1, y1, ..., x4, y4].
            polygon = [int(v) for point in points for v in point[:2]]
        else:
            # Any other point count falls back to the corners of the bounding rectangle.
            polygon = [x, y, x2, y, x2, y2, x, y2]
        per_img_anno.append({
            "polygon": polygon,
            "bbox": [x, y, x2, y2],
            "bbox_label": 0,
            # "###" is the transcription of instances that should be ignored.
            "ignore": item["transcription"] == "###",
        })

    img = cv2.imread(img_path)
    name_stem = os.path.splitext(images_path)[0]
    data_list.append({
        "instances": per_img_anno,
        "img_path": f'{mode}/{images_path}',
        "height": img.shape[0],
        "width": img.shape[1],
        "seg_map": f'{mode}/gt_{name_stem}.txt',
    })

mmocr_label['data_list'] = data_list
file_name = f"textdet_{mode}.json"
# The output folder must already exist.
with open(f'{root}/output/{file_name}', 'w') as out_file:
    json.dump(mmocr_label, out_file)
print("close")
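Before training, it can help to sanity-check the converted file. The sketch below is one possible check, not part of MMOCR: `find_problems` is a hypothetical helper that looks for missing `height`/`width` keys, polygons with fewer than four vertices, and malformed `bbox` entries. The `sample` dict is made-up data that only mimics the structure written by the conversion script.

```python
def find_problems(label):
    """Return readable descriptions of problems in an MMOCR-style textdet label dict."""
    problems = []
    for img in label.get("data_list", []):
        name = img.get("img_path", "<missing img_path>")
        for key in ("height", "width"):
            if key not in img:
                problems.append(f"{name}: missing '{key}'")
        for idx, inst in enumerate(img.get("instances", [])):
            polygon = inst.get("polygon", [])
            # A valid polygon has an even number of values and at least 4 vertices.
            if len(polygon) % 2 != 0 or len(polygon) < 8:
                problems.append(f"{name}: instance {idx} has {len(polygon)} polygon values")
            if len(inst.get("bbox", [])) != 4:
                problems.append(f"{name}: instance {idx} bbox is not [x1, y1, x2, y2]")
    return problems


# Made-up sample: the first image is fine, the second has two problems.
sample = {
    "data_list": [
        {"img_path": "test/a.jpg", "height": 100, "width": 200,
         "instances": [{"polygon": [0, 0, 10, 0, 10, 5, 0, 5],
                        "bbox": [0, 0, 10, 5], "bbox_label": 0, "ignore": False}]},
        {"img_path": "test/b.jpg", "height": 100, "weight": 200,
         "instances": [{"polygon": [0, 0, 10, 0, 5, 5],
                        "bbox": [0, 0, 10, 5], "bbox_label": 0, "ignore": False}]},
    ]
}

for problem in find_problems(sample):
    print(problem)
```

To check a real file, load it with `json.load` and pass the resulting dict to `find_problems`.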
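This script checks whether every PaddleOCR annotation has exactly four points. If you hit the error mentioned at the beginning, use it to check your labels: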
import json


def check_paddleocr_txt_annotations(txt_file_path):
    count = 0  # total number of annotations that do not have exactly four points
    try:
        with open(txt_file_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        for idx, line in enumerate(lines):
            try:
                # Each line is "<image path>\t<JSON list of annotations>".
                fields = line.strip().split('\t')
                file_path = fields[0]
                annotation_data = json.loads(fields[1])
                # Check the number of points of every annotation.
                for annotation in annotation_data:
                    transcription = annotation["transcription"]
                    points = annotation["points"]
                    if len(points) != 4:
                        count += 1
                        print(f"Invalid annotation found in file {file_path}:")
                        print(f"Transcription: {transcription}")
                        print(f"Points: {points}")
                        print("------")
            except Exception as e:
                print(f"Error occurred at line {idx + 1}: {e}")
    except Exception as e:
        print(f"Error occurred while checking annotation file: {e}")
    print('number:', count)


# Replace with the path to your own txt label file.
txt_file_path = '/home/lixinru/Vir_env/mmocr-main/data/brand_data/Label_all.txt'
check_paddleocr_txt_annotations(txt_file_path)