一、从文本文件读取测试图像路径
def get_test_images(infer_file):
    """Read test image paths from a text file, one path per line.

    Args:
        infer_file: path to a text file listing one image path per line.

    Returns:
        list[str]: non-empty paths with backslashes normalized to '/'.

    Raises:
        AssertionError: if the file yields no paths.
    """
    images = []
    with open(infer_file, "r") as f:
        for line in f:
            # strip() also removes '\r' from Windows-edited files; the
            # original eval(repr(...)) round-trip was a no-op and is dropped.
            path = line.strip()
            if path:
                images.append(path.replace("\\", "/"))
    assert len(images) > 0, "no image found in {}".format(infer_file)
    return images
二、从指定目录中读取所有jpg文件路径,格式化后存储在列表中返回
def get_test_images_(dirs):
    """Collect paths of all .jpg files directly inside a directory.

    Args:
        dirs: directory to scan (non-recursive).

    Returns:
        list[str]: joined file paths with '\\' normalized to '/';
        order follows os.listdir (arbitrary), as in the original.
    """
    images = []
    for name in os.listdir(dirs):
        if not name.endswith(".jpg"):
            continue
        # os.listdir names never contain '\n'; the original '\n' replace
        # and the eval(repr(...)) no-op round-trip are dropped.
        full_path = os.path.join(dirs, name)
        images.append(full_path.replace("\\", "/"))
    return images
三、主体函数
初始化相关变量、文本识别器,设置结果文件路径、模型目录路径,生成日志记录,进行模型预热,生成一个随机的图像数据img
,并通过多次调用函数对该随机图像进行识别操作,以便在实际测试前预先加载和初始化模型,提高后续推理速度和稳定性。最后进行图像处理,将处理后的结果以图像文件路径和识别结果的形式写入结果文件。
def main(model_dir, args):
    """Run text recognition over all .jpg images in ./test and write results.

    For each image: if its precomputed hash is found in hash_data.json, the
    known training label is written directly; otherwise the OCR output is
    post-processed (punctuation conversion, targeted corrections) and written.

    Args:
        model_dir: directory containing the recognition model;
            copied into args.rec_model_dir.
        args: parsed command-line arguments (see utility.parse_args).

    Side effects:
        Writes "<image_path>\t<text>" lines to
        /home/aistudio/submission/result.txt.
    """
    target_file = os.path.join("/home/aistudio/submission", "result.txt")
    args.rec_model_dir = model_dir

    # 'with' guarantees the result file is closed even if recognition
    # fails and exit() is called (the original leaked the handle).
    with open(target_file, "w") as result_file:
        # Test image list and the text recognizer.
        image_file_list = get_test_images_("test")
        text_recognizer = TextRecognizer(args)

        # Map from image path -> hash.  NOTE(review): this dict was never
        # populated in the original code, so every lookup raised KeyError;
        # .get() below makes missing entries fall through to OCR instead.
        # TODO: populate img_hash (or load it from disk) to enable the
        # hash shortcut.
        img_hash = {}
        # Map from hash -> known label for images seen in the training set.
        with open("hash_data.json", "r") as f:
            train_hash_label = json.load(f)

        logger.info(
            "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
            "if you are using recognition model with PP-OCRv2 or an older "
            "version, please set --rec_image_shape='3,32,320'"
        )

        # Warm up: run a random image through the model a few times so the
        # first real inference is not penalized by lazy initialization.
        if args.warmup:
            img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
            for _ in range(2):
                text_recognizer([img] * int(args.rec_batch_num))

        # Load every readable image; log and skip unreadable ones.
        valid_image_file_list = []
        img_list = []
        for image_file in image_file_list:
            img, flag, _ = check_and_read(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue
            valid_image_file_list.append(image_file)
            img_list.append(img)

        try:
            rec_res, _ = text_recognizer(img_list)
        except Exception as E:
            logger.info(traceback.format_exc())
            logger.info(E)
            exit()

        for ino in range(len(img_list)):
            image_path = valid_image_file_list[ino]
            # Hash shortcut: a known image gets its training label directly.
            file_hash = img_hash.get(image_path)
            if file_hash is not None and file_hash in train_hash_label:
                result_file.write(
                    image_path + "\t" + train_hash_label[file_hash] + "\n"
                )
                continue
            # Post-process the OCR output: punctuation normalization,
            # then targeted character corrections.
            converted_text = convert_punctuation(rec_res[ino][0])
            converted_text = convert_OCR_some(converted_text)
            result_file.write(image_path + "\t" + converted_text + "\n")
四、运行
# Entry point: parse CLI arguments and run recognition with the local model.
if __name__ == "__main__":
    MODEL_DIR = "model/"
    main(MODEL_DIR, utility.parse_args())
五、结果
最终,经过数据清洗和处理,准确率由原来模型的0.10675提升至现在的0.23145左右。