Complete Text Detection and Recognition | Source Code Included


Also, have you noticed how every store writes its name in its own distinctive way? Well-known brands such as Gucci, Sears, Pantaloons, and Lifestyle use curved or circular fonts in their logos. While this attracts customers, it poses a real challenge to the deep learning (DL) models that perform text detection and recognition.

What do you do when you read the text on a banner? Your eyes first detect the presence of text, locate each character, and then recognize those characters. That is exactly what a DL model has to do! OCR has recently become a hot topic in deep learning, with every new architecture striving to outperform the others.

Tesseract, the popular deep-learning-based OCR module, performs well on structured text such as documents, but struggles with curved, irregularly shaped text in fancy fonts. Fortunately, Clova AI provides excellent networks that outperform Tesseract on the wide variety of text appearances found in the real world. In this post we briefly discuss these architectures and learn how to stitch them together.

Text Detection with CRAFT

Scene text detection is the task of detecting text regions against complex backgrounds and marking them with bounding boxes. CRAFT, short for Character-Region Awareness For Text detection, is a 2019 architecture whose main goal is to localize individual character regions and link the detected characters into text instances.


CRAFT uses a fully convolutional network architecture based on VGG-16. Simply put, VGG-16 is the feature-extraction backbone that encodes the network's input into a feature representation. The decoding part of CRAFT is similar to UNet: it has skip connections that aggregate low-level features. CRAFT predicts two scores for each character:

  • Region score: as the name suggests, it gives the region of a character; it localizes each character. 

  • Affinity score: 'affinity' is the degree to which one substance tends to combine with another.

The affinity score therefore merges characters into a single instance (a word). CRAFT produces two maps as output: a region-level map and an affinity map. Let us see what they mean through an example:

[Figure: input image]

The regions where characters are present are marked in the region map:

[Figure: region map]

The affinity map graphically represents related characters. Red indicates that the characters have high affinity and must be merged into one word:

[Figure: affinity map]

Finally, the affinity and region scores are combined to give the bounding box of each word. The coordinates are ordered as (top-left), (top-right), (bottom-right), (bottom-left), where each coordinate is an (x, y) pair.
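
As a quick illustration of this ordering, the sketch below draws one such box on a blank canvas; the coordinate values are made up for the example.

import numpy as np
import cv2

# Hypothetical CRAFT-style word box: top-left, top-right, bottom-right, bottom-left
box = np.array([[120, 40], [260, 48], [258, 92], [118, 84]], dtype=np.float32)

canvas = np.full((200, 400, 3), 255, dtype=np.uint8)   # blank white image
pts = box.astype(np.int32).reshape(-1, 1, 2)           # shape expected by cv2.polylines
cv2.polylines(canvas, [pts], isClosed=True, color=(0, 0, 255), thickness=2)
cv2.imwrite('word_box.jpg', canvas)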

Why not a simple four-value format? 

Look at the image below: could you localize "LOVE" with only four values?

[Figure: the word "LOVE"]

CRAFT is multilingual, meaning it can detect text in any script.

Text Recognition: the Four-Stage Scene Text Recognition Framework 

In 2019, Clova AI published a paper that examined inconsistencies in existing scene text recognition (STR) datasets and proposed a unified framework into which most existing STR models fit.

[Figure: the four-stage STR framework]

Let us discuss the four stages (a minimal sketch of how they compose follows the list):

  1. Transformation: remember that we are dealing with scene text, which can be arbitrarily shaped and curved. If we went straight to feature extraction, the network would also have to learn the geometry of the input text, which is extra work for the feature-extraction module. The STR network therefore applies a thin-plate spline (TPS) transformation and normalizes the input text into a rectangular shape. 

  2. Feature extraction: maps the transformed image to a set of features relevant to character recognition; font, color, size, and background are discarded. The authors experimented with different backbones, including ResNet, VGG, and RCNN. 

  3. Sequence modeling: if I write 'ba_', you will most likely guess that the missing letter is 'd', 'g', or 't' rather than 'u' or 'p'. How do we teach the network to capture such contextual information? With BiLSTMs! BiLSTMs are memory-hungry, however, so this stage can be enabled or skipped as needed. 

  4. Prediction: this stage estimates the output character sequence from the features extracted from the image. 
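
Here is a minimal sketch of how the four stages compose; the class and argument names are placeholders for illustration, not the modules used in the Clova AI repository.

import torch.nn as nn

class FourStageSTR(nn.Module):
    """Illustrative composition of the four STR stages (placeholder modules)."""
    def __init__(self, transformation, feature_extractor, sequence_model, predictor):
        super().__init__()
        self.transformation = transformation        # e.g. TPS normalization, or identity
        self.feature_extractor = feature_extractor  # e.g. VGG / RCNN / ResNet backbone
        self.sequence_model = sequence_model        # e.g. BiLSTM, or identity to skip it
        self.predictor = predictor                  # e.g. CTC or attention decoder

    def forward(self, image):
        x = self.transformation(image)      # 1. normalize curved text to a rectangle
        x = self.feature_extractor(x)       # 2. map the image to character-relevant features
        x = self.sequence_model(x)          # 3. capture context between characters
        return self.predictor(x)            # 4. decode the character sequence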

The authors ran several experiments, choosing different networks for each stage. The accuracies are summarized in the table below:

[Table: accuracy for the different module combinations]

Code

CRAFT predicts a bounding box for each word. The four-stage STR takes a single word (as an image) as input and predicts its letters. If you are working with single-word images (such as CUTE80), OCR with these DL modules is effortless.

Step 1: Install the requirements

[Screenshot: installation commands]
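
A plausible Colab-style set of install commands (roughly the packages the two repositories list, plus pandas for the CSV step later on) looks like this:

!pip install torch torchvision opencv-python scikit-image scipy
!pip install lmdb pillow nltk natsort pandas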

Step 2: Clone the repositories

[Screenshot: clone commands]
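
Assuming the two Clova AI repositories used below, the clone commands look like this:

!git clone https://github.com/clovaai/CRAFT-pytorch.git
!git clone https://github.com/clovaai/deep-text-recognition-benchmark.git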

Step 3: Modify the code to return detection-box scores 

CRAFT returns the bounding boxes whose score is above a certain threshold. If you also want to see the score of each bounding box, a few changes to the original repository are needed. Open craft_utils.py in the cloned CRAFT repository; lines 83 and 239 need to change as shown below.

"""Modify to Return Scores of Detection Boxes"""


"""  
Copyright (c) 2019-present NAVER Corp.
MIT License
"""


# -*- coding: utf-8 -*-
import numpy as np
import cv2
import math


""" auxilary functions """
# unwarp corodinates
def warpCoord(Minv, pt):
    out = np.matmul(Minv, (pt[0], pt[1], 1))
    return np.array([out[0]/out[2], out[1]/out[2]])
""" end of auxilary functions """




def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    # prepare data
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape


    """ labeling method """
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)


    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)


    det = []
    det_scores = []
    mapper = []
    for k in range(1,nLabels):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue


        # thresholding
        if np.max(textmap[labels==k]) < text_threshold: continue


        # make segmentation map
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0   # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)


        # make box
        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)


        # align diamond-shape
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)


        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)


        det.append(box)
        mapper.append(k)
        det_scores.append(np.max(textmap[labels==k]))


    return det, labels, mapper, det_scores


def getPoly_core(boxes, labels, mapper, linkmap):
    # configs
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2


    polys = []  
    for k, box in enumerate(boxes):
        # size filter for small instance
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 10 or h < 10:
            polys.append(None); continue


        # warp image
        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except:
            polys.append(None); continue


        # binarization for selected label
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1


        """ Polygon generation """
        # find top/bottom contours
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length


        # pass if max_len is similar to h
        if h * max_len_ratio < max_len:
            polys.append(None); continue


        # get pivot points with fixed length
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg     # segment width
        pp = [None] * num_cp    # init pivot points
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0,len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0


                # reset variables
                seg_num += 1
                prev_h = -1


            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1


            if seg_num % 2 == 0: continue # No polygon area


            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h


        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]


        # pass if num of pivots is not sufficient or segment widh is smaller than character height 
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue


        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2


        # calc gradiant and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0:     # gradient if zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])


        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break


        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None); continue


        # make final polygon
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))


        # add to final result
        polys.append(np.array(poly))


    return polys


def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    boxes, labels, mapper, det_scores = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)


    if poly:
        polys = getPoly_core(boxes, labels, mapper, linkmap)
    else:
        polys = [None] * len(boxes)


    return boxes, polys, det_scores


def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
    if len(polys) > 0:
        polys = np.array(polys)
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
    return polys
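
With these changes, getDetBoxes_core records the peak region score of each surviving component and getDetBoxes returns it as det_scores. A minimal call looks like this (the two score maps below are random stand-ins, not real CRAFT output):

import numpy as np
import craft_utils   # the modified module from this step

score_text = np.random.rand(240, 320).astype(np.float32)   # stand-in region map
score_link = np.random.rand(240, 320).astype(np.float32)   # stand-in affinity map

boxes, polys, det_scores = craft_utils.getDetBoxes(
    score_text, score_link, text_threshold=0.7, link_threshold=0.4, low_text=0.4)
print(len(boxes), len(det_scores))   # one score per detected box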

Step 4: Remove the argument parser from CRAFT's test.py 

Open test.py and modify it as shown below; here we remove the argument parser (an args object is passed into test_net instead).

"""Modify to Remove Argument Parser"""


"""  
Copyright (c) 2019-present NAVER Corp.
MIT License
"""


# -*- coding: utf-8 -*-
import sys
import os
import time
import argparse


import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable


from PIL import Image


import cv2
from skimage import io
import numpy as np
import craft_utils
import imgproc
import file_utils
import json
import zipfile


from craft import CRAFT


from collections import OrderedDict
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict


def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, args, refine_net=None):
    t0 = time.time()


    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio


    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()


    # forward pass
    with torch.no_grad():
        y, feature = net(x)


    # make score and link map
    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()


    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()


    t0 = time.time() - t0
    t1 = time.time()


    # Post-processing
    boxes, polys, det_scores = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)


    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]


    t1 = time.time() - t1


    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)


    if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))


    return boxes, polys, ret_score_text, det_scores

Step 5: Write a separate script that saves image names and detection-box coordinates to a CSV file 

This will help us crop the words that go into the four-stage STR as input, and it keeps all the bounding-box and text information in one place. Create a new file (I named it pipeline.py) and add the following code.

import sys
import os
import time
import argparse


import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable


from PIL import Image


import cv2
from skimage import io
import numpy as np
import craft_utils
import test
import imgproc
import file_utils
import json
import zipfile
import pandas as pd


from craft import CRAFT


from collections import OrderedDict


from google.colab.patches import cv2_imshow


def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")


#CRAFT
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')


args = parser.parse_args()




""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)


image_names = []
image_paths = []


#CUSTOMISE START
start = args.test_folder


for num in range(len(image_list)):
  image_names.append(os.path.relpath(image_list[num], start))




result_folder = './Results'
if not os.path.isdir(result_folder):
    os.mkdir(result_folder)


if __name__ == '__main__':


    data=pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
    data['image_name'] = image_names


    # load net
    net = CRAFT()     # initialize


    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))


    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False


    net.eval()


    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))


        refine_net.eval()
        args.poly = True


    t = time.time()


    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)


        bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)
        
        bbox_score={}


        for box_num in range(len(bboxes)):
          key = str (det_scores[box_num])
          item = bboxes[box_num]
          bbox_score[key]=item


        data['word_bboxes'][k]=bbox_score
        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)


        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)


    data.to_csv('/content/Pipeline/data.csv', sep = ',', na_rep='Unknown')
    print("elapsed time : {}s".format(time.time() - t))

The pandas DataFrame (the variable data) stores the image names and the bounding boxes of the words they contain in separate columns. We strip the full path and keep only the image name relative to the test folder to avoid clutter; you can of course customize this to your own needs. Now run the script:

[Screenshot: running pipeline.py]
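
The command is along these lines (craft_mlt_25k.pth is the default weights file in the argument parser above; the test-folder path is only an example):

# the test_folder path below is an example; point it at your own images
!python pipeline.py --trained_model=weights/craft_mlt_25k.pth --test_folder=/content/Pipeline/Test_Images/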

At this stage the CSV looks like this: for each detection we store a Python dictionary mapping score to coordinates.

[Screenshot: CSV with the image_name and word_bboxes columns]

Step 6: Crop the words

We now have the coordinates and score of every box, so we can set a threshold and crop the words whose characters we want to recognize. Create a new script, crop_images.py, and remember to fill in your own paths where indicated. The cropped words are saved under the 'dir' folder: we create one folder per image and save the words cropped from it in the format <parent image>_<8 coordinates separated by underscores>, which lets you trace every cropped word back to its source image.

import os
import numpy as np
import cv2
import pandas as pd
from google.colab.patches import cv2_imshow


def crop(pts, image):


  """
  Takes inputs as 8 points
  and Returns cropped, masked image with a white background
  """
  rect = cv2.boundingRect(pts)
  x,y,w,h = rect
  cropped = image[y:y+h, x:x+w].copy()
  pts = pts - pts.min(axis=0)
  mask = np.zeros(cropped.shape[:2], np.uint8)
  cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
  dst = cv2.bitwise_and(cropped, cropped, mask=mask)
  bg = np.ones_like(cropped, np.uint8)*255
  cv2.bitwise_not(bg,bg, mask=mask)
  dst2 = bg + dst


  return dst2




def generate_words(image_name, score_bbox, image):


  num_bboxes = len(score_bbox)
  for num in range(num_bboxes):
    bbox_coords = score_bbox[num].split(':')[-1].split(',\n')
    if bbox_coords!=['{}']:
      l_t = float(bbox_coords[0].strip(' array([').strip(']').split(',')[0])
      t_l = float(bbox_coords[0].strip(' array([').strip(']').split(',')[1])
      r_t = float(bbox_coords[1].strip(' [').strip(']').split(',')[0])
      t_r = float(bbox_coords[1].strip(' [').strip(']').split(',')[1])
      r_b = float(bbox_coords[2].strip(' [').strip(']').split(',')[0])
      b_r = float(bbox_coords[2].strip(' [').strip(']').split(',')[1])
      l_b = float(bbox_coords[3].strip(' [').strip(']').split(',')[0])
      b_l = float(bbox_coords[3].strip(' [').strip(']').split(',')[1].strip(']'))
      pts = np.array([[int(l_t), int(t_l)], [int(r_t) ,int(t_r)], [int(r_b) , int(b_r)], [int(l_b), int(b_l)]])
      
      if np.all(pts) > 0:
        
        word = crop(pts, image)
        
        folder = '/'.join( image_name.split('/')[:-1])


        #CHANGE DIR
        dir = '/content/Pipeline/Crop Words/'


        if os.path.isdir(os.path.join(dir + folder)) == False :
          os.makedirs(os.path.join(dir + folder))


        try:
          file_name = os.path.join(dir + image_name)
          cv2.imwrite(file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t ,t_r, r_b , b_r ,l_b, b_l), word)
          print('Image saved to '+file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t ,t_r, r_b , b_r ,l_b, b_l))
        except:
          continue


data=pd.read_csv('PATH TO CSV')


start = 'PATH TO TEST IMAGES'  # fill in the folder that holds the original test images


for image_num in range(data.shape[0]):
  image = cv2.imread(os.path.join(start, data['image_name'][image_num]))
  image_name = os.path.splitext(data['image_name'][image_num])[0]  # drop the extension (str.strip would also eat leading/trailing j/p/g characters)
  score_bbox = data['word_bboxes'][image_num].split('),')
  generate_words(image_name, score_bbox, image)

Run the script:

[Screenshot: running crop_images.py]
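
Assuming the CSV path and the test-image path inside crop_images.py have been filled in, running it is simply:

!python crop_images.py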

Step 7: Recognition (finally!) 

You could now blindly run the recognition module on the cropped words, but if you want things to be more organized, modify it as shown below. We create a .txt file in each image folder and save the recognized words alongside the names of the cropped images. In addition, the predicted words are saved in the CSV we have been maintaining.

import string
import argparse


import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F


from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


import pandas as pd
import os


def demo(opt):




    """Open csv file wherein you are going to write the Predicted Words"""
    data = pd.read_csv('/content/Pipeline/data.csv')


    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)


    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)


    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))


    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)


    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)


            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)


                # Select max probabilty (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)


            else:
                preds = model(image, text_for_pred, is_train=False)


                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)


            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'
            
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')


            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                
                
                start = 'PATH TO CROPPED WORDS'  # fill in the root folder of the cropped words
                path = os.path.relpath(img_name, start)


                folder = os.path.dirname(path)


                image_name=os.path.basename(path)


                file_name='_'.join(image_name.split('_')[:-8])


                txt_file=os.path.join(start, folder, file_name)                
                
                log = open(f'{txt_file}_log_demo_result_vgg.txt', 'a')
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]


                # calculate confidence score (= multiply of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')
                log.write(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n')


            log.close()
  


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')


    opt = parser.parse_args()


    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).


    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
    # print (opt.image_folder)


    # pred_words=demo(opt)
    demo(opt)

After downloading the weights from the Clova AI STR GitHub repository, you can run the following command:

[Screenshot: recognition command]
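
Assuming the modified script above is saved as demo.py inside the STR repository, and that the TPS-ResNet-BiLSTM-Attn pretrained weights were downloaded, the command looks roughly like this (the image_folder path is the crop folder from the previous step):

!python demo.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn --image_folder "/content/Pipeline/Crop Words/" --saved_model TPS-ResNet-BiLSTM-Attn.pth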

We chose this combination of networks because of its high accuracy. The CSV now looks like this: pred_words holds the detection-box coordinates and the predicted word, separated by a colon.

[Screenshot: CSV with the pred_words column]

Conclusion

We have integrated two accurate models into a single detection-and-recognition module. With the predicted words and their bounding boxes in a single column, you can align and post-process the text any way you like!
