天池OCR竞赛实践

OCR Paddle环境配置和训练笔记:天池OCR竞赛coggle baseline环境配置和训练

OCR简介

OCR全称为Optical Character Recognition,即光学字符识别,简单来说就是从实际生活、生产中的图像提取出文字信息(具体定义可参考百度百科:光学字符识别)。这个任务可难可易,简单情形下的OCR任务准确率已经很高了,但在复杂情况下想要提高准确程度仍是一件相当难的任务。而本次比赛中就属于比较复杂的情形,尤其是小票这种数据,会有很多光照、笔记和不规则的角度等的影响,难度比较大。

训练

关于baseline模型的训练可以参考之前的本文开头提供的笔记链接,也可以参考paddleocr官方文档。需要注意的是,baseline工程中给出的训练其实是对paddleocr中已有的ch_ppocr_server_v2.0_det模型的finetune,也就是在一个已经与训练好的检测模型上针对天池OCR比赛数据集的优化。因此我们可以很容易想到,除了检测模型外,我们还可以对文字识别的模型进行finetune。

对识别模型进行finetune:

  1. 首先需要下载模型,用wget命令或直接在浏览器中输入https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar下载ch_ppocr_server_v2.0_rec_train.tar
  2. 解压到tianchi-intel-PaddleOCR\inference目录下
  3. 到tianchi-intel-PaddleOCR\configs\rec\ch_ppocr_v2.0目录下,这里有两个配置文件rec_chinese_common_train_v2.0.yml和rec_chinese_lite_train_v2.0.yml,使用的模型不一样,我选择了rec_chinese_common_train_v2.0.yml,打开文件,修改其中的label_file_list和data_dir参数
  4. 在命令行输入python tools/train.py -c .\configs\rec\ch_ppocr_v2.0\rec_chinese_common_train_v2.0.yml -o Global.pretrain_weights=./inference/ch_ppocr_server_v2.0_rec_train开始训练

万万没想到识别模型训练起来这么慢,我把batchsize改成了8,算了一下完成一个epoch需要跑4天多……于是我直接改了改代码,让它跑完120个iteration就停止,这样的话模型并没有看过整个训练集,不知道效果会怎么样。


跑了一晚上,今天早上试着提交了,没想到分数竟然是0!是因为模型没有把所有数据都过一遍吗?
其实在训练过程中就已经感觉不对了,训练时的acc一直是0,是因为batchsize太小的缘故吗?
训练中的输出


找了一天,发现原来baseline方案里面的参数名写错了。“pretrained_model”写成了“pretrained_weights”。
另外,训练recognizer时,train_list和test_list都需要重新生成,官网上给出的格式为:

" 图像文件名 图像标注信息 "

train_data/train_0001.jpg 简单可依赖 train_data/train_0002.jpg
用科技让复杂的世界更简单

  • 注意: 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错

这个其实只要在baseline的代码基础上改一下就可以了(原代码地址:生成训练集和验证集List,配合baseline使用):

idx = 0
for row in train.iloc[:-100].iterrows():
    path = json.loads(row[1]['原始数据'])['tfspath']
    img_path = IMAGE_PATH + path.split('/')[-1]
    labels = json.loads(row[1]['融合答案'])[0]
    if 'book' in path:
        print(path)
        idx += 1
        continue
        
    ann_text = ""
    for label in labels[:]:
        text = json.loads(label['text'])['text']
        coord = [int(float(x)) for x in label['coord']]
        
        ann_text+=text
#     break
#     print(ann_text, path)
    with open('./train_list_for_rec.txt', 'a+', encoding='utf-8') as up:
        up.write(f"image/{path.split('/')[-1]}\t{ann_text}\n")
        print("image/{}\t{}\n".format(path.split('/')[-1], ann_text))
idx = 0
for row in train.iloc[-100:].iterrows():
    path = json.loads(row[1]['原始数据'])['tfspath']
    img_path = IMAGE_PATH + path.split('/')[-1]
    labels = json.loads(row[1]['融合答案'])[0]
    if 'book' in path:
        print(path)
        idx += 1
        continue
        
    ann_text = ""
    for label in labels[:]:
        text = json.loads(label['text'])['text']
#         coord = [int(float(x)) for x in label['coord']]
        ann_text+=text
#     break
#     print(ann_text, path)
    with open('./test_list_for_rec.txt', 'a+', encoding='utf-8') as up:
        up.write(f"image/{path.split('/')[-1]}\t{ann_text}\n")
        print("image/{}\t{}\n".format(path.split('/')[-1], ann_text))

把格式改过来之后模型训练也快了很多,可能时因为标签从原先的json字典变成了单纯的字符串,数据量长度小了很多,不过训练一个epoch依然需要大约10个多小时的时间。
另外我还发现,预训练模型的识别能力很有限,官方给出的测试样本非常简单。
官方的测试样本
如果直接把比赛的图像丢进预训练模型中进行识别,效果非常差。于是我又去看了一下原始数据的标签,发现标签中的框都非常小,基本只能框住一行字,大概像下面这样:
框住的范围
对于这样的样本,识别结果的准确度还挺高的:
识别结果
所以究竟是否有必要对recognizer进行finetune还有待商榷,而且就算要finetune,也应该把所有样本都中的文字都截成一行一行的小图,问题是根据数据集中的矩形框坐标来看,这些矩形很有可能是倾斜的。所以应该如何根据一个倾斜的矩形框的顶点坐标来裁剪图像呢?


突然想到裁剪倾斜矩形框的代码可能不用自己从头想,我们只需要看看paddleocr在检测器输出了框的坐标后,是怎么把框内的图像喂给识别器的就行了。查看predict_system.py或predict_system_tianchi.py代码,发现里面有这样一个函数:

def get_rotate_crop_image(self, img, points):
        '''
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        '''
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

而且在text_system的__call__函数中确实调用了这个函数:

img_crop = self.get_rotate_crop_image(ori_im, tmp_box)

这里ori_im是原始图像,tmp_box是检测器输出的矩形框坐标。

可见paddleocr是通过四个点的坐标算出矩形框的长宽,然后对图像做变换,使得原本必须用倾斜的矩形才能框出的区域,现在可以用正着的矩形框出。


终于把数据集做出来了,除了上面提到的那个函数,还需要再调用两个函数对坐标做一些处理才能截出正确的结果。另外感觉数据量比较大,所以只用了Xeon1OCR_round1_train_20210524.csvXeon1OCR_round1_train2_20210526.csv两个,没想到还是截出了将近20万个数据。完整的代码如下:

def order_points_clockwise(pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        # sort the points based on their x-coordinates
        """
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-roodinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

def clip_det_res(points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

def filter_tag_det_res(dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

def get_rotate_crop_image(img, points):
        '''
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
#         cv2.imshow('img_crop', img_crop)
#         cv2.waitKey(0)
#         cv2.destroyAllWindows()
#         points[:, 0] = points[:, 0] - left
#         points[:, 1] = points[:, 1] - top
#         return img_crop
        '''
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
#         cv2.imshow('M', M)
#         cv2.waitKey(0)
#         cv2.destroyAllWindows()
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img


%matplotlib inline
img_id = 0
img_num = 1
TRAIN_DATA_PATH = 'train_data/tianchi/image/'
TRAIN_DATA_FOR_REC = 'train_data/tianchi/train_data_for_rec/image/'
if not os.path.exists(TRAIN_DATA_FOR_REC):
       os.makedirs(TRAIN_DATA_FOR_REC)
for row in train.iloc[:-100].iterrows():
    path = json.loads(row[1]['原始数据'])['tfspath']
    img_path = TRAIN_DATA_PATH + path.split('/')[-1]
    img = cv2.imread(img_path)
    img_height, img_width = img.shape[0:2]
    print( img_num, img_id, img_path, img.shape )
    labels = json.loads(row[1]['融合答案'])[0]
    for label in labels[:]:
        text = json.loads(label['text'])['text']
        coord = [int(float(x)) for x in label['coord']]
        point = np.float32( [coord[:2], coord[2:4],coord[4:6], coord[-2:]] )
#         print( text, point )
        point = order_points_clockwise(point)
        point = clip_det_res(point, img_height, img_width)
        img_crop = get_rotate_crop_image(img, point)
#         cv2.imshow('img_crop', img_crop)
#         cv2.waitKey(0)
#         cv2.destroyAllWindows()
        cv2.imwrite(TRAIN_DATA_FOR_REC+str(img_id)+'.jpg', img_crop)
        with open('train_data/tianchi/train_data_for_rec/train_list_for_rec.txt', 'a+', encoding='utf-8') as up:
            up.write(f"{str(img_id)}.jpg\t{text}\n")
        img_id += 1
    img_num += 1
#     if img_id == 1000:
#         break




print('*'*20, 'generate test samples', '*'*20)
TEST_DATA_FOR_REC = 'train_data/tianchi/train_data_for_rec/image/'
if not os.path.exists(TEST_DATA_FOR_REC):
       os.makedirs(TEST_DATA_FOR_REC)
for row in train.iloc[-100:].iterrows():
    path = json.loads(row[1]['原始数据'])['tfspath']
    img_path = TRAIN_DATA_PATH + path.split('/')[-1]
    img = cv2.imread(img_path)
    img_height, img_width = img.shape[0:2]
    print( img_num, img_id, img_path, img.shape )
    labels = json.loads(row[1]['融合答案'])[0]
    for label in labels[:]:
        text = json.loads(label['text'])['text']
        coord = [int(float(x)) for x in label['coord']]
        point = np.float32( [coord[:2], coord[2:4],coord[4:6], coord[-2:]] )
#         print( text, point )
        point = order_points_clockwise(point)
        point = clip_det_res(point, img_height, img_width)
        img_crop = get_rotate_crop_image(img, point)
        cv2.imwrite(TEST_DATA_FOR_REC+str(img_id)+'.jpg', img_crop)
        with open('train_data/tianchi/train_data_for_rec/test_list_for_rec.txt', 'a+', encoding='utf-8') as up:
            up.write(f"{str(img_id)}.jpg\t{text}\n")
        img_id += 1
    img_num += 1
#     if img_id == 1000:
#         break

在windows下很多文本的编码不是utf-8,而代码在读取数据时会按照utf-8解码,如下图所示(simple_dataset.py文件):
simple_dataset.py
所以在生成train_list和test_list时,要尽可能为函数指定encoding=‘utf-8’。


半夜发现还有坑……训练时必须要指定checkpoints,不然训练又会从头开始,之前的best_metric都为0,一旦eval完成,就会立即保存当前的模型。真的是无语了,那个pretrained_model和pretrain_weights到底是用来干什么的啊……


在Linux系统中可以直接在jupyter notebook中使用系统命令,语法就是在命令前面加一个“!”,但是在windows中系统下没有这些命令,其中有些有替换的命令,如“!ls”可以用“!dir”来代替,但是有些就没有。比如这次比赛的一个baseline中会用到的wget,此时需要在Windows中安装wget,参考方法如下:

原文网址:https://www.cnblogs.com/hzdx/p/6432161.html

  1. 下载[http://www.interlog.com/~tcharron/wgetwin.html]
  2. 解压到目录 比如我解压到D:\Tool\wget
  3. 添加wget环境变量,这样使用就更方便了,右键计算机->属性->高级系统设置->高级->环境变量->选中PATH->编辑,在最后添加 ;D:\Tool\wget (实际解压路径)
  4. 到此安装完成.


报错:

error Traceback (most recent call
last) in
21 point = np.array( [coord[:2], coord[2:4],coord[4:6], coord[-2:]] )
22 print( text, point )
—> 23 img_crop = get_rotate_crop_image(img, point)
24
25 break

in get_rotate_crop_image(img, points)
21 [img_crop_width, img_crop_height],
22 [0, img_crop_height]])
—> 23 M = cv2.getPerspectiveTransform(points, pts_std)
24 dst_img = cv2.warpPerspective(
25 img,

error: OpenCV(4.5.2)
D:\Build\OpenCV\opencv-4.5.2\modules\imgproc\src\imgwarp.cpp:3392:
error: (-215:Assertion failed) src.checkVector(2, CV_32F) == 4 &&
dst.checkVector(2, CV_32F) == 4 in function
‘cv::getPerspectiveTransform’

这个报错的原因是cv2.getPerspectiveTransform(points, pts_std)中points和pts_std的类型不一样,points是直接用np.array生成的坐标,pts_std是np.float32类型,把points也改成np.float32就好了。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值