【OCR】EAST算法数据处理——ICDAR_2015数据集

背景

下载地址:https://rrc.cvc.uab.es/?ch=4&com=downloads

注意:该数据为开源数据,但需要注册一个账号,简单注册下即可下载;

介绍:用于文本检测任务,数据包含1000张训练样本以及500张测试样本;

所需数据展示

首先,该数据是用于EAST算法,需要处理成模型所需的数据格式;

原始标签数据:

在这里插入图片描述

可以看出,每张图像对应一个文本,其中有多个框的标注信息,分别表示四个坐标点和类别;

所需数据格式:

在这里插入图片描述

需要的数据分为三个部分,Score map、d_map、θ_map,取文本位置中的每个像素点作为标签样本;

代码实现

主函数实现

class custom_dataset(data.Dataset):
    # scale表示图像缩放到原来的1/4,图像大小缩放为512x512
    def __init__(self, img_path, gt_path, scale=0.25, length=512):
        super(custom_dataset, self).__init__()
        # 因为图片和标签数据都是对应的,并且在Linux系统下是乱序读入,所以需要排序
        self.img_files = [os.path.join(img_path, img_file) for img_file in sorted(os.listdir(img_path))]
        self.gt_files  = [os.path.join(gt_path, gt_file) for gt_file in sorted(os.listdir(gt_path))]
        self.scale = scale
        self.length = length
        
	# 返回数据数量
    def __len__(self):
        return len(self.img_files)

    # __getitem__这个函数好用,可以建立个类,然后用下标来调用
    # 所有数据处理的调用都在这里进行
    def __getitem__(self, index):
    	# 读取标签文件中所有文本
        with open(self.gt_files[index], 'r', encoding='utf-8-sig') as f:
            lines = f.readlines()
        # 提取点和标签的函数
        vertices, labels = extract_vertices(lines)
        # 每次取得时候,都是随机取得,所以这样安排,[动态数据增强],注意到,动态得增强,省了不少空间 
        img = Image.open(self.img_files[index])
        # 为什么要随机缩放高度[0.8--1.2]:应该是数据增强
        img, vertices = adjust_height(img, vertices) 
        # 数据增强,旋转角度
        img, vertices = rotate_img(img, vertices)
        # 随机剪切(缩放)512x512的图片
        img, vertices = crop_img(img, vertices, labels, self.length) 
        # 函数功能:修改亮度、对比度和饱和度
        # 这里均值和方差都设置为0.5,实际上不太合理,但对后续影响不大
        transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
        # 得到所需要的训练数据,geo_map包括d1-d4和θ,ignored_map表示忽略部分
        score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
        return transform(img), score_map, geo_map, ignored_map

拓展

  • Pytorch中的transforms有多种用法,下面是一篇总结博客;

文章:https://blog.csdn.net/weixin_38533896/article/details/86028509

  • 图像缩放到原来的1/4是什么原因?

结合网络的结构分析可知,EAST模型通过编解码的形式,先下采样后上采样,最终输出特征图为输入的1/4,所以标签也需要和模型的输出相匹配;

在这里插入图片描述

可以看出,下采样的倍数为32倍,上采样的倍数为8倍,总体缩小了四倍,也就是输出特征图为原图的1/4;

extract_vertices函数

作用:得到标签文本中的点和类别信息

def extract_vertices(lines):
    '''
    Input:
        lines   : list of string info
    Output:
        vertices: vertices of text regions <numpy.ndarray, (n,8)>
        labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
    '''
    labels = []
    vertices = []
    for line in lines:
        # 消除一些无用信息,取出前八个数作为点坐标信息
        vertices.append(list(map(int,line.rstrip('\r\n').lstrip('\xef\xbb\xbf').split(',')[:8])))
        label = 0 if '###' in line else 1
        labels.append(label)
    return np.array(vertices), np.array(labels)

拓展

  • 字符串的一些处理技巧:

    首先map函数可用于类型转换,例如:map(int,str[]),对于数组中的所有元素都转成int类型;

    rstrip(’\r\n’):表示去除右边的空格和换行符 lstrip(’\xef\xbb\xbf’):消除左端一个UTF-8 BOM字符

adjust_height函数

作用:随机缩放高度[0.8,1.2],数据增强的操作;

def adjust_height(img, vertices, ratio=0.2):
    '''
    Input:
        img         : PIL Image
        vertices    : vertices of text regions <numpy.ndarray, (n,8)>
        ratio       : height changes in [0.8, 1.2]
    Output:
        img         : adjusted PIL Image
        new_vertices: adjusted vertices
    '''
    # np.random.rand()生成一个0-1的随机数,所以这里的范围是0.8到1.2
    ratio_h = 1 + ratio * (np.random.rand() * 2 - 1)
    old_h = img.height
    # np.around表示四舍五入
    new_h = int(np.around(old_h * ratio_h))
    # 对图像的高度进行缩放
    img = img.resize((img.width, new_h), Image.BILINEAR)
    
    new_vertices = vertices.copy()
    if vertices.size > 0:
        # 这里是只对坐标点y的值进行处理
        # 注意:切片操作需要熟练一些,在数据处理时很常用
        new_vertices[:,[1,3,5,7]] = vertices[:,[1,3,5,7]] * (new_h / old_h)
    return img, new_vertices

rotate_img函数

作用:旋转角度,数据增强的作用;

def rotate_img(img, vertices, angle_range=10):
    '''
    Input:
        img         : PIL Image
        vertices    : vertices of text regions <numpy.ndarray, (n,8)>
        angle_range : rotate range
    Output:
        img         : rotated PIL Image
        new_vertices: rotated vertices
    '''
    # 取出中心点坐标
    center_x = (img.width - 1) / 2
    center_y = (img.height - 1) / 2
    # 这里还是将旋转范围定位-10到10
    angle = angle_range * (np.random.rand() * 2 - 1)
    # 采用PIL Image中的旋转函数rotate
    img = img.rotate(angle, Image.BILINEAR)
    # 生成对应点想同纬度的全为0数组
    new_vertices = np.zeros(vertices.shape)
    for i, vertice in enumerate(vertices):
        # 这里对每个顶点也要进行翻转
        new_vertices[i,:] = rotate_vertices(vertice, -angle / 180 * math.pi, np.array([[center_x],[center_y]]))
    return img, new_vertices

# 对顶点进行旋转,这里传入的角度表示弧度,1度等于Π/180弧度
def rotate_vertices(vertices, theta, anchor=None):
    '''
    Input:    
        vertices: vertices of text region <numpy.ndarray, (8,)>
        theta   : angle in radian measure
        anchor  : fixed position during rotation
    Output:
        rotated vertices <numpy.ndarray, (8,)>
    '''
    v = vertices.reshape((4,2)).T
    if anchor is None:
        anchor = v[:,:1] 
    rotate_mat = get_rotate_mat(theta)
    # 点和矩阵进行点乘
    res = np.dot(rotate_mat, v - anchor)
    return (res + anchor).T.reshape(-1)

# 直接根据affine变换来把旋转矩阵填上
def get_rotate_mat(theta):
    # 返回一个旋转矩阵
    return np.array([[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]])

拓展

  • 旋转矩阵是怎样的?实际上是由初中所学的公式推导出来的,具体如下图所示:
  • 在这里插入图片描述

crop_img函数

作用:进行图像裁剪;

首先需要了解一下这个裁剪的流程:

第一步:先选出左上角的候选区域,作为点的选取,这里其实就是做一个范围内的裁剪;

在这里插入图片描述

第二步,判断裁剪框是否横跨文本框;

在这里插入图片描述

def crop_img(img, vertices, labels, length):
    '''
    Input:
        img         : PIL Image
        vertices    : vertices of text regions <numpy.ndarray, (n,8)>
        labels      : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
        length      : length of cropped image region
    Output:
        region      : cropped image region
        new_vertices: new vertices in cropped region
    '''
    h, w = img.height, img.width
    # confirm the shortest side of image >= length
    # 如果短边小于512,将短边放大到512
    if h >= w and w < length:
        img = img.resize((length, int(h * length / w)), Image.BILINEAR)
    elif h < w and h < length:
        img = img.resize((int(w * length / h), length), Image.BILINEAR)
    # 计算缩放后的比例
    ratio_w = img.width / w
    ratio_h = img.height / h
    # 断言(在程序判断中起重要作用)
    assert(ratio_w >= 1 and ratio_h >= 1)

    new_vertices = np.zeros(vertices.shape)
    if vertices.size > 0:
        # 文本框的标签x,y需要乘以一定比例
        new_vertices[:,[0,2,4,6]] = vertices[:,[0,2,4,6]] * ratio_w
        new_vertices[:,[1,3,5,7]] = vertices[:,[1,3,5,7]] * ratio_h

    # find random position
    # 找到随机裁剪初始点的位置,也就是第一步中的候选区域
    remain_h = img.height - length
    remain_w = img.width - length
    flag = True
    cnt = 0
    # crop图片的时候,不能把文本前切为两半。
    # 尝试1000次
    while flag and cnt < 1000:
        cnt += 1
        start_w = int(np.random.rand() * remain_w)
        start_h = int(np.random.rand() * remain_h)
        # 判断是否横跨文本框
        flag = is_cross_text([start_w, start_h], length, new_vertices[labels==1,:])
    box = (start_w, start_h, start_w + length, start_h + length)
    region = img.crop(box)
    if new_vertices.size == 0:
        return region, new_vertices    
    
    # 标准化标签的坐标
    new_vertices[:,[0,2,4,6]] -= start_w
    new_vertices[:,[1,3,5,7]] -= start_h
    return region, new_vertices

# 作用:判断裁剪是否横跨文本框
def is_cross_text(start_loc, length, vertices):
    '''
    Input:
        start_loc: left-top position
        length   : length of crop image
        vertices : vertices of text regions <numpy.ndarray, (n,8)>
    Output:
        True if crop image crosses text region
    '''
    if vertices.size == 0:
        return False
    # 裁剪框的长和宽
    start_w, start_h = start_loc
    a = np.array([start_w, start_h, start_w + length, start_h, \
          start_w + length, start_h + length, start_w, start_h + length]).reshape((4,2))
    # p1表示矩阵a表示的矩形框
    p1 = Polygon(a).convex_hull
    for vertice in vertices:
        p2 = Polygon(vertice.reshape((4,2))).convex_hull
        # 可以算出两个矩形框的重叠区域(也就是IOU的值)
        inter = p1.intersection(p2).area
        # 0.0和1.0均不算横跨(包含或无交集的情况)
        p2_area = p2.area
        if p2.area == 0:
            p2_area = 0.00000001
        if 0.01 <= inter / p2_area <= 0.99: 
            return True
    return False

拓展

  • 如何判断裁剪框是否横框文本框呢?实际上就是求两个矩形的IOU值;

    采用Python的图形库——shapely

    第一步:创建两个矩形框(不规则图形也可以)

    from shapely.geometry import Polygon
    # convex_hull的作用是计算凸包
    a = Polygon([(0, 0), (0, 1), (1, 0), (1, 1)]).convex_hull
    b = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]).convex_hull
    

    第二步:计算重合部分面积

    inter = a.intersection(b).area
    

    第三步:求IOU的值(也就是重合部分对目标的占比)

    inter / a.area
    

    当然这是一种计算IOU的快捷方式,基本上就是调用库来实现(可能对效率有一定影响,而且引用库也添加了额外的操作),实际上可用numpy实现IOU的计算,具体实现可自行查找;

get_score_geo函数

作用:获取训练所需map数据;

首先要说明一些概念,也是对部分代码的一个图示;

shrink_poly函数起到缩小文本框的作用,主要是实现所需数据展示部分中缩小文本框的效果;

计算θ的map值:

find_min_rect_angle函数是为了找到角度,原理图如下:

在这里插入图片描述

实现步骤:

1、通过遍历0-180°,每次旋转后都将外接矩形的面积保存下来;

2、和原始文本框的面积作对比,取出差值最小的十个旋转框;

3、计算十个框的拟合误差,返回误差最小的弧度;

计算d的map值:

计算文本框中的点离文本框边界的距离,需要先将整个图片旋转,保证文本框处于水平方向,原理图如下:

在这里插入图片描述

实现步骤:

1、先将图像进行旋转,保证文本框处于水平方向;

2、计算每个像素点距离边界的距离,这里如果小于0的值置为0,也就是负数为0;

def get_score_geo(img, vertices, labels, scale, length):
    '''
    Input:
        img     : PIL Image
	# quad 的8个顶点
        vertices: vertices of text regions <numpy.ndarray, (n,8)>
        labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
        scale   : feature map / image	# 为0.25,根据网络结构理解
        length  : image length
    Output:
        score gt, geo gt, ignored
    '''
    # 产生零map,注意这里产生了5个map
    score_map = np.zeros((int(img.height * scale), int(img.width * scale), 1), np.float32)
    geo_map = np.zeros((int(img.height * scale), int(img.width * scale), 5), np.float32)
    # 忽略的map,训练中不需要用到
    ignored_map = np.zeros((int(img.height * scale), int(img.width * scale), 1), np.float32)

    #按照length和1/scale产生这个np.meshgrid是进行下采样缩小四倍
    # np.arange是按步长生成数组
    index = np.arange(0, length, int(1/scale))          #  间隔4个像素取一个点
    # np.meshgrid:从一个坐标向量中返回一个坐标矩阵
    index_x, index_y = np.meshgrid(index, index)
    ignored_polys = []
    polys = []
    
    # 遍历顶点,给geo_map赋值
    for i, vertice in enumerate(vertices):
        #记录需要忽略掉得四边形
        if labels[i] == 0:
            ignored_polys.append(np.around(scale * vertice.reshape((4,2))).astype(np.int32))
            continue        
        # 产生一个缩小0.3倍的poly(选定的正样本标签) 也就是文本框向内缩小0.3
        # scale只为fillPoly使用,实际标签不缩小4倍
        # shrink_poly的作用就是向内缩小0.3倍
        poly = np.around(scale * shrink_poly(vertice).reshape((4,2))).astype(np.int32)
        polys.append(poly)
        # 单个文本框的掩码,用于限定d1_map,d2_map,d3_map,d4_map中哪些位置应该赋值
        temp_mask = np.zeros(score_map.shape[:-1], np.float32)
        # 将全为0的mask中指定区域设置为1,也就是目标区域
        cv2.fillPoly(temp_mask, [poly], 1)
        
        # 通过遍历的方法,找到最小外接矩形,然后找到矩形的角度
        theta = find_min_rect_angle(vertice)
        
        # 旋转文本框,并旋转所有像素坐标,旋转为theta=0的水平状态,方便计算d
        # 找到矩形的角度对应到的旋转角度
        # get_rotate_mat作用:返回一个旋转矩阵(每个点的θ都是一样的)
        rotate_mat = get_rotate_mat(theta)
        rotated_vertices = rotate_vertices(vertice, theta)
        x_min, x_max, y_min, y_max = get_boundary(rotated_vertices)
        # 得到旋转后的x,y的值,为[512, 512]
        rotated_x, rotated_y = rotate_all_pixels(rotate_mat, vertice[0], vertice[1], length)
    
        # 计算d,负数表示在文本框外侧,置零
        d1 = rotated_y - y_min
        d1[d1<0] = 0
        d2 = y_max - rotated_y
        d2[d2<0] = 0
        d3 = rotated_x - x_min
        d3[d3<0] = 0
        d4 = x_max - rotated_x
        d4[d4<0] = 0
        # 每隔4个像素采样(index_y, index_x),乘以文本框缩小0.3后的掩码
        # 这里的作用是:通过掩码将0.3外的部分的d值去掉
        geo_map[:,:,0] += d1[index_y, index_x] * temp_mask
        geo_map[:,:,1] += d2[index_y, index_x] * temp_mask
        geo_map[:,:,2] += d3[index_y, index_x] * temp_mask
        geo_map[:,:,3] += d4[index_y, index_x] * temp_mask
        geo_map[:,:,4] += theta * temp_mask
    
    # 忽略标签为###的文本框
    cv2.fillPoly(ignored_map, ignored_polys, 1)
    # 得到[score_map] 全部文本框的掩码,即得分标签
    cv2.fillPoly(score_map, polys, 1)
    return torch.Tensor(score_map).permute(2,0,1), torch.Tensor(geo_map).permute(2,0,1), torch.Tensor(ignored_map).permute(2,0,1)

# 作用:缩小文本框
def shrink_poly(vertices, coef=0.3):
    '''
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
        coef    : shrink ratio in paper
    Output:
        v       : vertices of shrinked text region <numpy.ndarray, (8,)>
    '''
    x1, y1, x2, y2, x3, y3, x4, y4 = vertices
    # 获取每个点的短边,后面缩小范围的时候使用
    # 使用欧式距离计算每个点相邻边中最小的那条边
    r1 = min(cal_distance(x1,y1,x2,y2), cal_distance(x1,y1,x4,y4))
    r2 = min(cal_distance(x2,y2,x1,y1), cal_distance(x2,y2,x3,y3))
    r3 = min(cal_distance(x3,y3,x2,y2), cal_distance(x3,y3,x4,y4))
    r4 = min(cal_distance(x4,y4,x1,y1), cal_distance(x4,y4,x3,y3))
    r = [r1, r2, r3, r4]

    # 判断哪两个对边较长,后续move_points先移动长边
    # obtain offset to perform move_points() automatically
    if cal_distance(x1,y1,x2,y2) + cal_distance(x3,y3,x4,y4) > \
       cal_distance(x2,y2,x3,y3) + cal_distance(x1,y1,x4,y4):
        offset = 0 # two longer edges are (x1y1-x2y2) & (x3y3-x4y4)
    else:
        offset = 1 # two longer edges are (x2y2-x3y3) & (x4y4-x1y1)

    v = vertices.copy()
    # 先移动长边,再移动短边,这个先后有关系么?
    # 先移动短边上得两个点得话,得到得缩小后的四边形面积更小一些
    v = move_points(v, 0 + offset, 1 + offset, r, coef)
    v = move_points(v, 2 + offset, 3 + offset, r, coef)
    v = move_points(v, 1 + offset, 2 + offset, r, coef)
    v = move_points(v, 3 + offset, 4 + offset, r, coef)
    return v

# 作用:找到最佳弧度
def find_min_rect_angle(vertices):
    '''
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
    Output:
        the best angle <radian measure>
    '''
    # 每隔1度遍历所有角度
    angle_interval = 1
    angle_list = list(range(-90, 90, angle_interval))
    area_list = []
    for theta in angle_list: 
        rotated = rotate_vertices(vertices, theta / 180 * math.pi)
        x1, y1, x2, y2, x3, y3, x4, y4 = rotated
        # 直接计算平行于x轴和y轴的包围框
        temp_area = (max(x1, x2, x3, x4) - min(x1, x2, x3, x4)) * \
                    (max(y1, y2, y3, y4) - min(y1, y2, y3, y4))
        area_list.append(temp_area)
    
    # 相当于argsort,获取从小到大的排序索引
    sorted_area_index = sorted(list(range(len(area_list))), key=lambda k : area_list[k])
    min_error = float('inf')
    best_index = -1
    rank_num = 10
    # 遍历前10个最小面积的矩形,计算拟合误差,并返回拟合误差最小的弧度
    # find the best angle with correct orientation
    for index in sorted_area_index[:rank_num]:
        rotated = rotate_vertices(vertices, angle_list[index] / 180 * math.pi)
        # 计算拟合误差
        temp_error = cal_error(rotated)
        if temp_error < min_error:
            min_error = temp_error
            best_index = index
    return angle_list[best_index] / 180 * math.pi # 返回弧度

拓展:

  • np.meshgrid的作用是什么?

    在这里np.meshgrid的作用是表示间隔像素点的位置,具体案例如下:

    a = np.arange(0, 16, 4)			# array([ 0,  4,  8, 12])
    x, y = np.meshgrid(index, index)
    print(x)
    """
    array([[ 0,  4,  8, 12],
           [ 0,  4,  8, 12],
           [ 0,  4,  8, 12],
           [ 0,  4,  8, 12]])
    """
    print(y)
    """
    array([[ 0,  0,  0,  0],
           [ 4,  4,  4,  4],
           [ 8,  8,  8,  8],
           [12, 12, 12, 12]])
    """
    

    可以看出,(x,y)对应的矩阵就是对应每个像素点的索引,根据该索引得到的图像相当于是原图进行了1/4的下采样操作,可能会造成一定的信息损失;

  • cv2.fillPoly的作用?

    用于对图像中的掩码部分进行处理,修改为指定的值,下图为原理图:

在这里插入图片描述

总结

到这里ICDAR_2015的数据终于处理好了,本次数据处理是为了EAST模型使用,如果使用其他文本检测的模型,数据则需要做其他的处理;从整个代码可以看出,数据处理的工作量挺大的,并且还会遇到很多细节问题,出问题就会导致标签错误,整个任务都等于白做了!(在AI领域有一句话:数据是AI的基石,也是AI的上限)

其实在日常工作中,数据处理往往占了算法工程师一大部分时间,如何根据业务场景处理数据是最关键的一步;至于模型的选型,往往改动并不大,最繁琐也最费时间的步骤就是在数据处理这步!

所以作为算法工程师来说,对数据和业务场景的理解,对图像的基本处理,都是必备的技能。

  • 6
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值