# OCR: PSE text-detection algorithm, annotated dataloader (part 1)

# dataloader add 3.0 scale
# dataloader add filer text
import numpy as np
from PIL import Image
from torch.utils import data
import util
import cv2
import random
import torchvision.transforms as transforms
import torch
import pyclipper
import Polygon as plg


ic15_root_dir = '/src/notebooks/train_data/weituoshu_20200115/AuthImg_up_train/'

ic15_train_data_dir = ic15_root_dir + 'image/'
ic15_train_gt_dir = ic15_root_dir + 'label/'
ic15_test_data_dir = ic15_root_dir + 'image/'
ic15_test_gt_dir = ic15_root_dir + 'label/'

random.seed(123456)

def get_img(img_path):
    """Read an image from disk and return it as an RGB numpy array.

    cv2.imread loads images in BGR order; the channel axis is reordered to
    RGB before returning.

    Raises:
        IOError: if the file cannot be read. Note cv2.imread does NOT raise
            on a bad path -- it silently returns None -- so we must check
            explicitly instead of relying on a downstream TypeError.
    """
    img = cv2.imread(img_path)
    if img is None:
        # Keep the original diagnostic: print the offending path, then fail.
        print(img_path)
        raise IOError('can not read image: %s' % img_path)
    return img[:, :, [2, 1, 0]]  # BGR -> RGB (fancy index keeps a contiguous copy)

def get_bboxes(img, gt_path):
    """Parse a ground-truth txt file into normalized quads and validity tags.

    Each line is "x1,y1,...,x4,y4,label". A label starting with '#' marks a
    damaged/illegible region and gets tag False (ignored during training).
    Coordinates are divided by the image width/height so boxes live in [0, 1]
    and survive later rescaling.

    Returns:
        (np.ndarray of shape (n, 8), list[bool]) -- normalized boxes and tags.
    """
    h, w = img.shape[0:2]
    norm = np.array([w * 1.0, h * 1.0] * 4)
    bboxes = []
    tags = []
    for raw in util.io.read_lines(gt_path):
        # Strip UTF-8 BOM residue in both byte-escaped and codepoint form.
        raw = util.str.remove_all(raw, '\xef\xbb\xbf')
        raw = util.str.remove_all(raw, '\ufeff')
        fields = util.str.split(raw, ',')
        tags.append(fields[-1][0] != '#')
        quad = np.asarray([int(fields[k]) for k in range(8)]) / norm
        bboxes.append(quad)
    return np.array(bboxes), tags

def random_horizontal_flip(imgs):
    """With probability 0.5, mirror every image in `imgs` left-right.

    All images are flipped together (one coin toss) so the color image and
    its ground-truth maps stay aligned. Mutates and returns `imgs`.
    """
    if random.random() < 0.5:
        for idx, im in enumerate(imgs):
            imgs[idx] = im[:, ::-1].copy()  # flip axis 1; copy to drop the view
    return imgs

def random_rotate(imgs):
    """Rotate every image in `imgs` by one shared random angle in [-10, 10] deg.

    Rotation is about the image center; output keeps the original size
    (corners are cut off / padded with black by warpAffine). Mutates and
    returns `imgs`.
    """
    max_angle = 10
    theta = random.random() * 2 * max_angle - max_angle
    for idx, im in enumerate(imgs):
        rows, cols = im.shape[:2]
        matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), theta, 1)
        imgs[idx] = cv2.warpAffine(im, matrix, (cols, rows))
    return imgs

def scale(img, long_size=2240):
    """Resize `img` so its longer side equals `long_size`, preserving aspect ratio."""
    h, w = img.shape[0:2]
    ratio = long_size * 1.0 / max(h, w)
    return cv2.resize(img, dsize=None, fx=ratio, fy=ratio)

def random_scale(img, min_size):
    """Randomly rescale `img` for training-time scale augmentation.

    First clamps the longer side to at most 1280 px, then multiplies by a
    random factor from {0.5, 1, 2, 3}. If that would push the shorter side
    to `min_size` or below, the factor is replaced so the short side ends up
    at min_size + 10.
    """
    h, w = img.shape[0:2]
    if max(h, w) > 1280:
        clamp = 1280.0 / max(h, w)
        img = cv2.resize(img, dsize=None, fx=clamp, fy=clamp)

    h, w = img.shape[0:2]
    factor = np.random.choice(np.array([0.5, 1.0, 2.0, 3.0]))
    if min(h, w) * factor <= min_size:
        # Guarantee the short side stays comfortably above min_size.
        factor = (min_size + 10) * 1.0 / min(h, w)
    return cv2.resize(img, dsize=None, fx=factor, fy=factor)

def my_scale(img, min_size):
    """Clamp the long side to at most 1280 px, then downscale by a fixed 0.5.

    Bug fixes vs. the original: the `def` line was mis-indented relative to
    its body (SyntaxError), and `scale` was read before assignment whenever
    max(h, w) <= 1280 (UnboundLocalError); `scale` now starts at 1.0.

    NOTE(review): the final resize uses a hard-coded 0.5 factor and ignores
    the recomputed `scale`, matching the original statement -- confirm that
    is intended.
    """
    h, w = img.shape[0:2]
    scale = 1.0
    if max(h, w) > 1280:
        scale = 1280.0 / max(h, w)
        img = cv2.resize(img, dsize=None, fx=scale, fy=scale)

    h, w = img.shape[0:2]
    if min(h, w) * scale <= min_size:
        scale = (min_size + 10) * 1.0 / min(h, w)
    img = cv2.resize(img, dsize=None, fx=0.5, fy=0.5)
    return img

def random_crop(imgs, img_size):
    """Crop the same random (th, tw) window out of every image in `imgs`.

    imgs[0] is the color image and imgs[1] the text ground-truth map; all
    entries must share the same height/width. With probability 5/8 -- and
    only when imgs[1] contains any positive pixel -- the window start is
    sampled so the crop overlaps the text region; otherwise it is uniform
    over the whole image.
    """
    h, w = imgs[0].shape[0:2]
    th, tw = img_size
    if w == tw and h == th:
        return imgs
    
    if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
        # Top-left/bottom-right corners of the text region, shifted up-left by
        # the crop size so any start point in [tl, br] keeps text in view.
        tl = np.min(np.where(imgs[1] > 0), axis = 1) - img_size  #np.where(imgs[1] > 0)  
        tl[tl < 0] = 0
        br = np.max(np.where(imgs[1] > 0), axis = 1) - img_size
        br[br < 0] = 0
        br[0] = min(br[0], h - th)              # clamp so the crop window stays inside the image
        br[1] = min(br[1], w - tw)              # sample a start (i, j) in [tl, br]; crop spans (i, j)-(i+th, j+tw)
        
        i = random.randint(tl[0], br[0])
        j = random.randint(tl[1], br[1])
    else:
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
    
    # return i, j, th, tw
    for idx in range(len(imgs)):
        # 3-D entries (H, W, C) keep their channel axis; 2-D masks are sliced as-is.
        if len(imgs[idx].shape) == 3:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
        else:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw]
    return imgs

def dist(a, b):
    """Euclidean distance between two points given as numpy arrays."""
    diff = a - b
    return np.sqrt((diff * diff).sum())

def perimeter(bbox):
    """Total boundary length of a closed polygon given as an (n, 2) vertex array."""
    n = bbox.shape[0]
    # Sum edge lengths, wrapping from the last vertex back to the first.
    return sum(dist(bbox[k], bbox[(k + 1) % n]) for k in range(n))

def shrink(bboxes, rate, max_shr=20):
    """Shrink each polygon with pyclipper to build the PSENet kernel boxes.

    bboxes:  iterable of (n_points, 2) integer polygons in pixel coordinates.
    rate:    kernel scale in (0, 1]; squared below, giving the paper's offset
             d = area * (1 - rate^2) / perimeter.
    max_shr: cap on the shrink offset, in pixels.

    Falls back to the original polygon whenever clipping produces nothing
    usable, so the output always has one entry per input box.
    """
    rate = rate * rate
    shrinked_bboxes = []
    for bbox in bboxes:
        area = plg.Polygon(bbox).area()
        peri = perimeter(bbox)

        pco = pyclipper.PyclipperOffset()
        pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        # Offset distance per the paper; +0.001 guards division by zero, +0.5 rounds.
        offset = min((int)(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
        
        shrinked_bbox = pco.Execute(-offset)
        if len(shrinked_bbox) == 0:
            # Shrinking collapsed the polygon entirely; keep the original.
            shrinked_bboxes.append(bbox)
            continue
        
        shrinked_bbox = np.array(shrinked_bbox)[0]
        if shrinked_bbox.shape[0] <= 2:
            # Degenerate result (fewer than 3 vertices); keep the original.
            shrinked_bboxes.append(bbox)
            continue
        
        shrinked_bboxes.append(shrinked_bbox)
    
    return np.array(shrinked_bboxes)

class IC15Loader(data.Dataset):
    """ICDAR15-style dataset for PSENet training.

    is_transform: enable training-time augmentation (random scale, flip,
                  rotate, crop, color jitter).
    img_size:     target crop size; an int is expanded to a square (h, w).
    kernel_num:   number of score maps per image (full text map + shrunken
                  kernels G1..G(kernel_num-1), as in the PSE paper).
    min_scale:    shrink rate of the smallest kernel (Vatti clipping parameter).

    Fixes vs. the original: removed the debug cv2.imwrite('img.jpg') /
    cv2.imwrite('label.jpg') calls that overwrote files in the working
    directory on every __getitem__, and renamed the kernel-loop index so it
    no longer shadows the inner contour-loop index.
    """

    def __init__(self, is_transform=False, img_size=None, kernel_num=7, min_scale=0.4):
        # Collect every image path and its matching label txt path.
        self.is_transform = is_transform

        self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
        self.kernel_num = kernel_num
        self.min_scale = min_scale

        data_dirs = [ic15_train_data_dir]
        gt_dirs = [ic15_train_gt_dir]

        self.img_paths = []
        self.gt_paths = []

        for data_dir, gt_dir in zip(data_dirs, gt_dirs):
            # util.io.ls returns the file names with the given extension.
            img_names = util.io.ls(data_dir, '.jpg')
            img_names.extend(util.io.ls(data_dir, '.png'))
            img_names.extend(util.io.ls(data_dir, '.JPG'))

            for img_name in img_names:
                self.img_paths.append(data_dir + img_name)
                # Label file shares the image's base name: foo.jpg -> foo.txt
                gt_name = img_name.split('.')[0] + '.txt'
                self.gt_paths.append(gt_dir + gt_name)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        gt_path = self.gt_paths[index]

        img = get_img(img_path)                  # RGB image (H, W, 3)
        # bboxes: (n, 8) quads normalized to [0, 1] so they survive rescaling;
        # tags[i] is False for '#'-marked (illegible) boxes to be ignored.
        bboxes, tags = get_bboxes(img, gt_path)

        if self.is_transform:
            img = random_scale(img, self.img_size[0])  # scale jitter for robustness

        # gt_text: 0 = background, i + 1 inside box i (binarized to {0,1} later).
        # training_mask: 1 everywhere except inside ignored ('#') boxes.
        gt_text = np.zeros(img.shape[0:2], dtype='uint8')
        training_mask = np.ones(img.shape[0:2], dtype='uint8')
        if bboxes.shape[0] > 0:
            # Denormalize to pixel coordinates and reshape (n, 8) -> (n, 4, 2).
            bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4),
                                (bboxes.shape[0], bboxes.shape[1] // 2, 2)).astype('int32')
            for i in range(bboxes.shape[0]):
                # Fill box i with label i + 1 (thickness -1 = filled contour).
                cv2.drawContours(gt_text, [bboxes[i]], -1, i + 1, -1)
                if not tags[i]:
                    cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)

        # Shrunken kernel maps G1..G(kernel_num-1), per the PSE paper formula.
        gt_kernels = []
        for k in range(1, self.kernel_num):
            rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * k
            gt_kernel = np.zeros(img.shape[0:2], dtype='uint8')
            kernel_bboxes = shrink(bboxes, rate)  # same layout as bboxes: (n, pts, 2)
            for i in range(bboxes.shape[0]):
                cv2.drawContours(gt_kernel, [kernel_bboxes[i]], -1, 1, -1)
            gt_kernels.append(gt_kernel)

        if self.is_transform:
            # Apply the same geometric augmentations to the image and every map.
            imgs = [img, gt_text, training_mask]
            imgs.extend(gt_kernels)

            imgs = random_horizontal_flip(imgs)
            imgs = random_rotate(imgs)
            imgs = random_crop(imgs, self.img_size)

            img, gt_text, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:]

        gt_text[gt_text > 0] = 1  # collapse per-box labels to a binary text map
        gt_kernels = np.array(gt_kernels)

        img = Image.fromarray(img)
        img = img.convert('RGB')
        if self.is_transform:
            img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img)

        img = transforms.ToTensor()(img)
        # ImageNet mean/std normalization (standard for the pretrained backbone).
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        gt_text = torch.from_numpy(gt_text).float()
        gt_kernels = torch.from_numpy(gt_kernels).float()
        training_mask = torch.from_numpy(training_mask).float()

        return img, gt_text, gt_kernels, training_mask
