数据扩充-旋转对应标注的矩形框旋转(Curve-Text-Detector)

本脚本主要用于Curve-Text-Detector数据旋转扩充,ctd标注数据转icdar2015标注数据

import cv2
import numpy as np
import math
def rotation_point(img, angle=15, point=None):
    """Rotate an image about its centre and apply the same rotation to points.

    The output canvas is enlarged so the rotated image is not cropped.

    Args:
        img: image array (only ``img.shape[:2]`` is used for the geometry).
        angle: rotation angle in degrees, passed to cv2.getRotationMatrix2D
            (OpenCV: positive means counter-clockwise, origin top-left).
        point: array-like of (x, y) pairs (any count, nested or flat);
            ``None`` means "no points".

    Returns:
        (rotated image, float ndarray of shape (N, 2) with the transformed
        points; shape (0, 2) when ``point`` is None).
    """
    rows, cols = img.shape[:2]
    # Float division: under Python 2 ``cols / 2`` on ints truncates and
    # shifts the rotation centre by half a pixel for odd image sizes.
    M = cv2.getRotationMatrix2D((cols / 2.0, rows / 2.0), angle, 1)
    sin_a = math.fabs(math.sin(math.radians(angle)))
    cos_a = math.fabs(math.cos(math.radians(angle)))
    heightNew = int(cols * sin_a + rows * cos_a)
    widthNew = int(rows * sin_a + cols * cos_a)
    # Translate so the rotated image sits in the middle of the larger canvas.
    M[0, 2] += (widthNew - cols) / 2.0
    M[1, 2] += (heightNew - rows) / 2.0
    img = cv2.warpAffine(img, M, (widthNew, heightNew))

    if point is None:
        return img, np.zeros((0, 2))
    # reshape(-1, 2) accepts both [[x, y], ...] and a flat [x, y, x, y, ...].
    pts = np.asarray(point, dtype=np.float64).reshape(-1, 2)
    # Row-vector form of the affine map: p' = p . A^T + b
    pts = np.dot(pts, M[:, :2].T) + M[:, 2]
    return img, pts
img = cv2.imread('/media/d_2/everyday/0513/a7691515_s.jpg')
# cv2.imread does not raise on a missing/unreadable file; it returns None.
if img is None:
    raise IOError('cannot read /media/d_2/everyday/0513/a7691515_s.jpg')
point = np.array([[100,201],[550,201],[235,251],[550,500],[100,500]],np.int32)
cv2.polylines(img, [point], True, (0, 255, 255))

point_tmp = point.copy()
img_rot, point_rt = rotation_point(img, 15, point_tmp)
# cv2.polylines needs int32 vertices (plain ``int`` is int64 on 64-bit
# Linux); round to the nearest pixel instead of truncating toward zero.
point_rt_rt = np.rint(point_rt).astype(np.int32)
cv2.polylines(img_rot, [point_rt_rt], True, (255, 0, 255), 5)
cv2.imshow('img_rot', img_rot)
cv2.imshow('img', img)
cv2.waitKey(0)

可以处理任意个数的点。另外,python list 作为函数参数时有点儿奇怪:函数内部对 list 的原地修改会直接影响调用方。
示例图片:
(此处原有一张示例图片,导出时已丢失。)

应用

一张图片 对应 一个txt
a.jpg a.txt
其中,a.txt存放了如下格式的数据:
442,162,763,202,0,5,52,4,104,3,156,2,208,1,260,1,321,0,320,36,268,36,216,37,164,38,112,38,60,39,1,40
11,167,404,210,0,18,63,15,127,12,191,9,255,6,319,3,393,0,392,32,328,33,264,35,200,37,136,39,72,41,6,43
520,131,791,163,0,4,44,3,88,2,132,2,176,1,220,0,271,0,269,28,225,28,181,29,137,30,93,30,49,31,2,32
每行32个数据:前4个是目标外接矩形的左上、右下角坐标(x1,y1,x2,y2);后面14个数据是目标上边沿7个点的坐标,再后面14个数据是目标下边沿7个点的坐标。这28个坐标都是相对于外接矩形左上角(x1,y1)的偏移量。有多少行就代表有多少个目标。
本脚本的功能就是对已经标好的数据做扩充:旋转图片的同时,把对应的标注点也做同样的旋转。

list做函数参数有点儿奇怪
2019.05.24更新:
ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。) ヾ(。`Д´。)
这些表情能代表我的心情吗??重大bug啊!!!!
今天我把ctd这个标注数据集转为四边形的四个顶点,发现有问题!!转的四个点挨个画框不是矩形,并且原本没有旋转的数据画出的是矩形,旋转扩充的就是交叉的线,我再看rot的label,发现前面7个点的x不是递增的!!!!所以我确定这里的rot代码有问题。然后一步步查。

def GetRect(points):
    """Return the axis-aligned bounding box [xs, ys, xe, ye] of ``points``.

    NOTE: this is the original, BUGGY version, kept for illustration.
    ``points.sort(...)`` sorts the caller's list in place, so after the call
    the caller's points are no longer in their labelled order (they are left
    sorted by y). Sort a copy, or use min()/max(), instead.
    """
    temp = []  # unused initial value; reassigned below
    # print(points)
    points.sort(key=takeFirst)  # in place: reorders the caller's list!
    xs = points[0][0]
    xe = points[len(points) - 1][0]
    points.sort(key=takeSecond)  # in place again: caller ends up sorted by y
    ys = points[0][1]
    ye = points[len(points) - 1][1]
    temp = [xs, ys, xe, ye]
    return temp

发现经过这个函数后points改变了!!!!!!!!又是list做函数参数!!!!!!!!!都是坑啊!!!(list.sort() 是原地排序,会直接改掉调用方传进来的列表,点的顺序因此被打乱。)这个错误很难排查出来:即使我后面写了画点的函数,画出来的点位置也都是对的,看不出顺序问题;但是ctd训练需要点是有顺序的,从左上点开始,顺时针标注!!!难怪我扩充的很多数据反而效果不好!!!一度怀疑人生!!!!

#coding=utf-8
#python 2
import cv2
import numpy as np
import math
import os


def rotation_point(img, angle=15, point=None):
    """Rotate an image about its centre and apply the same rotation to points.

    The output canvas is enlarged so the rotated image is not cropped.

    Args:
        img: image array (only ``img.shape[:2]`` is used for the geometry).
        angle: rotation angle in degrees, passed to cv2.getRotationMatrix2D
            (OpenCV: positive means counter-clockwise, origin top-left).
        point: array-like of (x, y) pairs (any count, nested or flat);
            ``None`` means "no points".

    Returns:
        (rotated image, int ndarray of shape (N, 2) with the transformed
        points rounded to the nearest pixel; shape (0, 2) when ``point`` is
        None).
    """
    rows, cols = img.shape[:2]
    # Float division: this script targets Python 2, where ``cols / 2`` on
    # ints truncates and shifts the rotation centre by half a pixel.
    M = cv2.getRotationMatrix2D((cols / 2.0, rows / 2.0), angle, 1)
    sin_a = math.fabs(math.sin(math.radians(angle)))
    cos_a = math.fabs(math.cos(math.radians(angle)))
    heightNew = int(cols * sin_a + rows * cos_a)
    widthNew = int(rows * sin_a + cols * cos_a)
    # Translate so the rotated image sits in the middle of the larger canvas.
    M[0, 2] += (widthNew - cols) / 2.0
    M[1, 2] += (heightNew - rows) / 2.0

    img = cv2.warpAffine(img, M, (widthNew, heightNew))

    if point is None:
        return img, np.zeros((0, 2), dtype=int)
    # reshape(-1, 2) accepts both [[x, y], ...] and a flat [x, y, x, y, ...].
    pts = np.asarray(point, dtype=np.float64).reshape(-1, 2)
    # Row-vector form of the affine map: p' = p . A^T + b
    pts = np.dot(pts, M[:, :2].T) + M[:, 2]
    # Round rather than truncate: astype(int) alone rounds toward zero and
    # biases every coordinate by up to one pixel.
    return img, np.rint(pts).astype(int)


def label2Mypt(path_label_txt, vv_Mypt=None):
    """Read a CTD label file into per-target lists of absolute (x, y) points.

    Each non-blank line holds comma-separated integers: the bounding box
    (x1, y1, x2, y2) followed by 14 (dx, dy) offsets relative to (x1, y1).
    The offsets are converted to absolute (x, y) tuples.

    Args:
        path_label_txt: path of the label .txt file.
        vv_Mypt: optional output list. If given it is cleared and filled in
            place (the out-parameter style the callers use); otherwise a new
            list is created. (Previously ``[]`` -- a mutable default shared
            between calls.)

    Returns:
        The filled list (the same object as ``vv_Mypt`` when one is passed):
        one list of 14 (x, y) tuples per label line.
    """
    if vv_Mypt is None:
        vv_Mypt = []
    # Clear in place (not ``vv_Mypt = []``) so a caller-supplied list is
    # the one that gets filled.
    vv_Mypt[:] = []
    with open(path_label_txt, 'r') as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing empty lines
        v_val = line.split(',')
        x0 = int(v_val[0])
        y0 = int(v_val[1])
        v_pt = []
        for index in range(14):
            pt_x = x0 + int(v_val[4 + index * 2])
            pt_y = y0 + int(v_val[4 + index * 2 + 1])
            v_pt.append((pt_x, pt_y))
        vv_Mypt.append(v_pt)
    return vv_Mypt


def draw_circle(img, vv_Mypt):
    """Draw every point of every target on ``img`` (in place) and return it.

    Each point is drawn as a circle of radius 4, thickness 3, colour
    (0, 255, 255) -- yellow in OpenCV's BGR order.
    """
    colour = (0, 255, 255)
    every_point = (pt for target in vv_Mypt for pt in target)
    for centre in every_point:
        cv2.circle(img, centre, 4, colour, 3)
    return img


def draw_circle_rot(img, vv_Mypt):
    """Draw all target points on ``img`` in place, then show it and wait.

    Circles use radius 4, thickness 3, colour (255, 0, 0). The image is shown
    in the 'img_rot' window and the call blocks until a key is pressed.
    """
    colour = (255, 0, 0)
    for centre in (pt for target in vv_Mypt for pt in target):
        cv2.circle(img, centre, 4, colour, 3)
    cv2.imshow('img_rot', img)
    cv2.waitKey(0)



def vv_Mypt2point(vv_pt):
    """Flatten per-target point lists into a single list of [x, y] lists.

    Args:
        vv_pt: iterable of targets, each an iterable of (x, y) points.

    Returns:
        A new list with one ``[x, y]`` list per point, targets concatenated
        in their original order.
    """
    # Single comprehension; the old nested loop re-bound the outer loop
    # variable ``v_pt`` inside the inner loop, which only worked by accident.
    return [[pt[0], pt[1]] for target in vv_pt for pt in target]


def point2vv_Mypt(vv_pt, group_size=14):
    """Regroup a flat sequence of [x, y] points into per-target tuple lists.

    Args:
        vv_pt: iterable of points indexable as ``p[0]``, ``p[1]`` (e.g. a list
            of [x, y] or an (N, 2) ndarray).
        group_size: points per target (14 for CTD labels).

    Returns:
        A list of lists of ``group_size`` (x, y) tuples. A trailing partial
        group (fewer than ``group_size`` points) is dropped.

    Raises:
        ValueError: if ``group_size`` is not positive.
    """
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    tuples = [(v_val[0], v_val[1]) for v_val in vv_pt]
    n_groups = len(tuples) // group_size
    # Slicing builds a fresh list per group, so no group aliases another.
    return [tuples[i * group_size:(i + 1) * group_size] for i in range(n_groups)]


def takeFirst(elem):
    """Sort key: return element 0 of ``elem`` (the x of an (x, y) point)."""
    first = elem[0]
    return first


def takeSecond(elem):
    """Sort key: return element 1 of ``elem`` (the y of an (x, y) point)."""
    second = elem[1]
    return second

# def GetRect(points):
#     temp = []
#     # print(points)
#     points.sort(key=takeFirst)
#     xs = points[0][0]
#     xe = points[len(points) - 1][0]
#     points.sort(key=takeSecond)
#     ys = points[0][1]
#     ye = points[len(points) - 1][1]
#     temp = [xs, ys, xe, ye]
#     return temp

def GetRect(points_t):
    """Return the axis-aligned bounding box of a set of points.

    Args:
        points_t: iterable of points indexable as ``pt[0]`` (x) and ``pt[1]``
            (y), e.g. a list of (x, y) tuples or an (N, 2) ndarray. It is NOT
            modified.

    Returns:
        ``[xs, ys, xe, ye]``: min x, min y, max x, max y.

    Raises:
        ValueError: if ``points_t`` is empty.
    """
    # Work on a private list and use min()/max(): sorting the caller's list in
    # place silently destroyed the labelled point order. (``points_t[:]`` is
    # only a shallow copy, and for an ndarray it is a view whose .sort() takes
    # no key.)
    points = list(points_t)
    if not points:
        raise ValueError("GetRect() needs at least one point")
    xs_all = [pt[0] for pt in points]
    ys_all = [pt[1] for pt in points]
    return [min(xs_all), min(ys_all), max(xs_all), max(ys_all)]


def vv_Mypt2label_sub(path_save, v_Mypt):
    """Append one CTD-format label line for a single target to ``path_save``.

    The line is the bounding box (from GetRect) followed by every point as an
    (x, y) offset from the box's top-left corner, all comma-separated.

    Args:
        path_save: label file to append to (created if missing).
        v_Mypt: list of (x, y) points of one target, in labelled order.
    """
    rect_waijie = GetRect(v_Mypt)  # bounding ("waijie") rectangle
    x0, y0 = rect_waijie[0], rect_waijie[1]
    fields = [str(val) for val in rect_waijie]
    for val in v_Mypt:
        fields.append(str(int(val[0]) - x0))
        fields.append(str(int(val[1]) - y0))
    line = ",".join(fields) + "\n"
    # 'a' (append): the old 'aw' mode is rejected by Python 3's open().
    with open(path_save, 'a') as f:
        f.write(line)


def vv_Mypt2label(path_save, vv_Mypt):
    """Append one label line per target in ``vv_Mypt`` to ``path_save``."""
    for target_points in vv_Mypt:
        vv_Mypt2label_sub(path_save, target_points)


def generate_rot():
    """Augment a CTD dataset by rotating every labelled image.

    For each image in ``<path_root>/img`` that has a label file in
    ``<path_root>/label``, write one rotated copy per angle in ``angles`` to
    ``./img_rot`` and its rotated label (same text format) to
    ``./label_rot``, named ``<stem>_<angle>.jpg`` / ``.txt``.

    os.mkdir raises if the output directories already exist. Label lines are
    written in append mode (vv_Mypt2label_sub), so re-running into existing
    directories would duplicate lines -- clear them first.
    """
    path_root = '/media/data_2/everyday/0524/ctd_test_rot/'
    path_img = path_root + 'img'
    path_label = path_root + 'label'
    path_rot_img = './img_rot'
    path_rot_label = './label_rot'
    os.mkdir(path_rot_img)
    os.mkdir(path_rot_label)

    list_img = os.listdir(path_img)
    angles = [15, -15, 5, -5, 10, -10]
    cnt = 0
    for img_name in list_img:
        cnt += 1
        print("cnt=%d:::%s" % (cnt, img_name))
        path_label_txt = path_label + "/" + img_name.replace('.jpg', '.txt')
        if not os.path.exists(path_label_txt):
            continue
        img = cv2.imread(path_img + "/" + img_name)
        if img is None:
            # cv2.imread returns None (no exception) for unreadable files.
            print("skip unreadable image: %s" % img_name)
            continue
        vv_mypt = []
        label2Mypt(path_label_txt, vv_mypt)
        vv_pt = vv_Mypt2point(vv_mypt)
        for angle in angles:
            name_rot_img = img_name.replace('.jpg', '_' + str(angle) + '.jpg')
            name_rot_txt = name_rot_img.replace('.jpg', '.txt')
            path_save_rot_img = path_rot_img + '/' + name_rot_img
            path_save_rot_txt = path_rot_label + '/' + name_rot_txt
            img_rot, v_pt = rotation_point(img, angle, vv_pt)
            vv_Mypt_1 = point2vv_Mypt(v_pt)
            vv_Mypt2label(path_save_rot_txt, vv_Mypt_1)
            cv2.imwrite(path_save_rot_img, img_rot)


def draw_img_label():
    """Draw label points on every image and check that they lie inside it.

    Reads ``<path_root>/img`` and ``<path_root>/label`` and writes each image
    with its points drawn to ``<path_root>/draw_rot_circle`` (os.mkdir raises
    if that directory already exists).

    Raises:
        RuntimeError: at the first label point outside its image. This
            replaces a ``while True: pass`` busy loop that hung the script
            and pegged a CPU core.
    """
    path_root = "/media/data_2/everyday/0524/test_data/"
    path_img = path_root + 'img'
    path_label = path_root + 'label'
    path_draw = path_root + 'draw_rot_circle'
    os.mkdir(path_draw)

    list_img = os.listdir(path_img)
    cnt = 0
    for img_name in list_img:
        cnt += 1
        print("cnt=%d:::%s" % (cnt, img_name))
        path_img_name = path_img + "/" + img_name
        img = cv2.imread(path_img_name)
        if img is None:
            # cv2.imread returns None (no exception) for unreadable files.
            print("skip unreadable image: %s" % img_name)
            continue
        path_label_txt = path_label + "/" + img_name.replace('.jpg', '.txt')
        if not os.path.exists(path_label_txt):
            continue
        vv_mypt = []
        label2Mypt(path_label_txt, vv_mypt)
        img_draw = draw_circle(img, vv_mypt)
        cv2.imwrite(path_draw + '/' + img_name, img_draw)
        height, width = img.shape[:2]
        for v_pt in vv_mypt:
            for pt in v_pt:
                x = pt[0]
                y = pt[1]
                if x < 0 or x >= width or y < 0 or y >= height:
                    print("**************************************")
                    print("%s pt out img!!!!" % (img_name))
                    print("pt: %s" % (pt,))
                    print("shape: %s" % (img.shape,))
                    print("**************************************")
                    raise RuntimeError("%s: point %s is outside image of shape %s"
                                       % (img_name, pt, img.shape))


def ctd_label2icdar2015(transcription='###'):
    """Convert CTW1500 curve labels into ICDAR 2015-style quadrilateral gt.

    Each image in ``text_image`` with a label in ``text_label_curve`` is
    re-saved as ``img_<cnt>.jpg``. Its gt file ``gt_img_<cnt>.txt`` gets one
    line per target: ``x1,y1,x2,y2,x3,y3,x4,y4,<transcription>``.

    The quadrilateral uses curve points 0 and 6 (first/last top points) and
    7 and 13 (first/last bottom points). Assuming CTD points run clockwise
    from the top-left, that gives TL, TR, BR, BL.

    Args:
        transcription: text written as the last field of every gt line.
            Defaults to '###' because the curve labels carry no text. NOTE:
            in ICDAR 2015 ground truth '###' marks a "do not care" region,
            and many detector loaders ignore or mask such boxes during
            training -- verify how your loader treats it, and pass a
            placeholder word if the boxes must be learned.
    """
    path_root = "/media/data_1/Yang/project/2019/project/PSENet/data/CTW1500/train/"
    path_img = path_root + 'text_image'
    path_label = path_root + 'text_label_curve'

    path_dicdar2015_img = path_root + 'icdar2015_my/Challenge4/ch4_training_images/'
    path_dicdar2015_gt = path_root + 'icdar2015_my/Challenge4/ch4_training_localization_transcription_gt/'
    os.makedirs(path_dicdar2015_img)
    os.makedirs(path_dicdar2015_gt)

    quad_idx = [0, 6, 7, 13]
    list_img = os.listdir(path_img)
    cnt = 0
    for img_name in list_img:
        cnt += 1
        print("cnt=%d:::%s" % (cnt, img_name))
        path_label_txt = path_label + "/" + img_name.replace('.jpg', '.txt')
        if not os.path.exists(path_label_txt):
            continue
        img = cv2.imread(path_img + "/" + img_name)
        if img is None:
            # cv2.imread returns None (no exception); imwrite(None) would crash.
            print("skip unreadable image: %s" % img_name)
            continue
        vv_mypt = []
        label2Mypt(path_label_txt, vv_mypt)
        v_str_out = []
        for v_pt in vv_mypt:
            fields = []
            for idx in quad_idx:
                pt = v_pt[idx]
                fields.append(str(pt[0]))
                fields.append(str(pt[1]))
            fields.append(transcription)
            v_str_out.append(','.join(fields) + '\n')

        img_new_name = 'img_' + str(cnt) + '.jpg'
        txt_new_name = 'gt_img_' + str(cnt) + '.txt'
        cv2.imwrite(path_dicdar2015_img + img_new_name, img)
        with open(path_dicdar2015_gt + txt_new_name, 'w') as f:
            f.writelines(v_str_out)


if __name__ == "__main__":
    # Available steps (enable the one you need):
    #   generate_rot()        -- rotate img/ + label/ into ./img_rot + ./label_rot
    #   draw_img_label()      -- draw label points and check they are in-image
    #   ctd_label2icdar2015() -- convert CTW1500 curve labels to ICDAR 2015 quads
    ctd_label2icdar2015()

##下面是遇到的坑###########################################################################

def point2vv_Mypt(vv_pt):
    """Regroup a flat sequence of [x, y] points into lists of 14 (x, y) tuples.

    A trailing group with fewer than 14 points is dropped.
    """
    groups = []
    current = []
    for count, xy in enumerate(vv_pt, start=1):
        current.append((xy[0], xy[1]))
        if count % 14 == 0:
            # Store the finished list and start a brand-new one, so the stored
            # group is never cleared afterwards by reusing the same object.
            groups.append(current)
            current = []
    return groups

一开始这么写的,得到的vv_My_pt全是空的(append 进去的是 v_tmp 这个对象本身,紧接着 v_tmp[:] = [] 又把它清空,所以 vv_My_pt 里的每个元素都是同一个被清空的列表):
    if (idex + 1) % 14 == 0:
        vv_My_pt.append(v_tmp)
        v_tmp[:] = []
def label2Mypt(path_label_txt, vv_Mypt=None):
    """Read a CTD label file into per-target lists of absolute (x, y) points.

    Each non-blank line holds comma-separated integers: the bounding box
    (x1, y1, x2, y2) followed by 14 (dx, dy) offsets relative to (x1, y1).
    The offsets are converted to absolute (x, y) tuples.

    Args:
        path_label_txt: path of the label .txt file.
        vv_Mypt: optional output list. If given it is cleared and filled in
            place (the out-parameter style the callers use); otherwise a new
            list is created. (Previously ``[]`` -- a mutable default shared
            between calls.)

    Returns:
        The filled list (the same object as ``vv_Mypt`` when one is passed):
        one list of 14 (x, y) tuples per label line.
    """
    if vv_Mypt is None:
        vv_Mypt = []
    # Clear in place (not ``vv_Mypt = []``) so a caller-supplied list is
    # the one that gets filled.
    vv_Mypt[:] = []
    with open(path_label_txt, 'r') as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing empty lines
        v_val = line.split(',')
        x0 = int(v_val[0])
        y0 = int(v_val[1])
        v_pt = []
        for index in range(14):
            pt_x = x0 + int(v_val[4 + index * 2])
            pt_y = y0 + int(v_val[4 + index * 2 + 1])
            v_pt.append((pt_x, pt_y))
        vv_Mypt.append(v_pt)
    return vv_Mypt

一开始这么写,vv_Mypt 根本传不出去(函数里的 vv_Mypt = [] 只是让局部变量指向了一个新列表,调用方传进来的列表始终没有被修改):
    def label2Mypt(path_label_txt, vv_Mypt=[]):
    # BROKEN first attempt, kept for illustration (its indentation was also
    # mangled when pasted into the post).
    # Rebinding the parameter on the next line only makes the local name point
    # to a brand-new list; the list the caller passed in is never touched, so
    # the parsed points are lost when the function returns (it returns None).
    vv_Mypt = []
    with open(path_label_txt, 'r') as f:
        lines = f.readlines()
    # vv_Mypt = []
    for line in lines:
        v_val = line.split(',')
        v_pt = []
        for index in range(14):
            pt = ()
            pt_x = int(v_val[0]) + int(v_val[4 + index * 2])
            pt_y = int(v_val[1]) + int(v_val[4 + index * 2 + 1])
            pt = (pt_x, pt_y)
            v_pt.append(pt)
        vv_Mypt.append(v_pt)  # appends to the local list only; caller never sees it
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值