TT100K数据集增强方法:替换法、黏贴法以及常规扩充方法

 TT100K 数据集由清华大学与腾讯的联合实验室整理并公布,提供的 10 万张图像中包含了 30000 个交通标志。图像来源于 6 台高分辨率广角单反相机在中国多个城市拍摄的腾讯街景全景图,各拍摄地点的光照条件和天气条件有所不同。原始街景全景图的分辨率为 8192x2048,将全景图裁剪为四份后,最终数据集中图像的尺寸为 2048x2048。TT100K 数据集所含交通标志的类别较为全面,整个数据集共出现 221 种不同的类别,其中有标注的为 128 类。

数据集下载地址:TT100K

数据集处理脚本:TT100K2VOC

使用YOLO格式进行训练的可以参考我的博客:VOC2YOLO

下面主要分享 TT100K 数据集的几种增强方法,共 3 种,代码及效果展示如下:

  1. 替换法
import os
import json
import shutil
import tqdm
import cv2
import numpy as np
import random
from lxml import etree as ET
from xml.dom import minidom
from random import sample
from skimage import exposure, transform
import math
from PIL import Image, ImageStat, ImageEnhance
from DataAugmentForObejctDetection import DataAugmentForObjectDetection
from xml_helper import generate_xml, parse_xml


# 将解析的目标信息转化为xml文件
def edit_xml(objects, id, dir, label45=True):
    """Write the annotations of one image to a VOC-style XML file.

    objects: iterable of instance dicts; each one is read through
        obj["category"] and obj["bbox"]["xmin"/"ymin"/"xmax"/"ymax"].
    id: image name without extension. The file is saved as "<id>.xml" and
        the <filename> element is set to "<id>.jpg".
    dir: directory the XML file is written into.
    label45: when True, instances whose category is not a key of the class
        JSON file are left out of the XML.

    The class table is read from the module-level ``classesPath``.
    The <size> element is always written as 2048 x 2048 x 3.
    """
    # Context manager so the class-table handle is closed deterministically.
    with open(classesPath) as class_file:
        kept_classes = set(json.load(class_file))
    save_xml_path = os.path.join(dir, "%s.xml" % id)

    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "none"
    ET.SubElement(root, "filename").text = id + ".jpg"
    ET.SubElement(root, "source").text = "none"
    ET.SubElement(root, "owner").text = "danchaofan"
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(2048)
    ET.SubElement(size, "height").text = str(2048)
    ET.SubElement(size, "depth").text = "3"
    ET.SubElement(root, "segmented").text = "0"

    for obj in objects:
        # Skip instances outside the kept class list when filtering is on.
        if label45 and obj["category"] not in kept_classes:
            continue
        obj_node = ET.SubElement(root, "object")
        ET.SubElement(obj_node, "name").text = obj["category"]
        ET.SubElement(obj_node, "pose").text = "Unspecified"
        ET.SubElement(obj_node, "truncated").text = "0"
        ET.SubElement(obj_node, "difficult").text = "0"
        bndbox = ET.SubElement(obj_node, "bndbox")
        # Coordinates are truncated to integers, in the order the original
        # file emitted them.
        for corner in ("xmin", "ymin", "xmax", "ymax"):
            ET.SubElement(bndbox, corner).text = str(int(obj["bbox"][corner]))

    # Serialise once and pretty-print with minidom; the previous version
    # wrote the file, re-parsed it and then overwrote it with the same data.
    pretty_xml = minidom.parseString(
        ET.tostring(root, encoding="utf-8")).toprettyxml(indent="\t")
    with open(save_xml_path, "w", encoding="utf-8") as xml_file:
        xml_file.write(pretty_xml)


# 获取指定文件夹的图片全部id
def getDirId(dir):
    """Return the ids of the ".jpg" and ".xml" entries found in ``dir``.

    An id is the part of the file name before its first ".". Entries are
    returned in ``os.listdir`` order; a name present as both ".jpg" and
    ".xml" is listed twice.
    """
    wanted_suffixes = (".jpg", ".xml")
    return [
        entry.split(".")[0]
        for entry in os.listdir(dir)
        if entry.endswith(wanted_suffixes)
    ]


# 找出训练集和测试集中的不在45类的标注图片的id,只要图片里面有一个实例是45类标签里的,就排除
# def is_tt45(classNames, objects):
#     """
#         objects: dict,实例
#         classesPath:str,类别json文件的地址
#     """
#     for obj in objects:
#         name = obj["category"]
#         if name in classNames:
#             return True
#     return False

# 找出训练集和测试集中的不在45类的标注图片的id,只要图片里面存在非45实例就满足
def is_tt45(classNames, objects):
    """Return True when every instance's category is one of ``classNames``.

    classNames: collection of accepted category names.
    objects: iterable of instance dicts, each read through obj["category"].

    An image with no instances yields True. An empty ``classNames`` yields
    False for any non-empty ``objects``; the earlier nested-loop version
    raised UnboundLocalError in that case because its inner flag was never
    assigned.
    """
    return all(obj["category"] in classNames for obj in objects)


def marksCount(annos, classesPath, imgPath):
    """Count instances per class over the annotated images found in a folder.

    annos: dict loaded from the annotation JSON; read through
        annos["imgs"][id]["objects"][i]["category"].
    classesPath: path of the class JSON file; its keys are the classes that
        get counted.
    imgPath: folder whose file ids (as returned by getDirId) select which
        annotated images are counted.

    Prints "<class> : <count>" for every class and returns the count dict.
    """
    # Start every known class at zero; the handle is closed by the context
    # manager.
    with open(classesPath) as class_file:
        classesCount = dict.fromkeys(json.load(class_file), 0)

    # A set keeps the per-image membership test O(1) instead of scanning the
    # whole id list once per annotated image.
    present_ids = set(getDirId(imgPath))

    for img_id in tqdm.tqdm(annos["imgs"].keys()):
        if img_id not in present_ids:
            continue
        for obj in annos["imgs"][img_id]["objects"]:
            category = obj['category']
            if category in classesCount:
                classesCount[category] += 1

    for key, value in classesCount.items():
        print(key, ":", value)
    return classesCount


def Not45marksCount(annos, classesPath, imgPath):
    """Count instances per class and tally the instances of unlisted classes.

    annos: dict loaded from the annotation JSON; read through
        annos["imgs"][id]["objects"][i]["category"].
    classesPath: path of the class JSON file; its keys are the classes that
        get counted individually.
    imgPath: folder whose file ids (as returned by getDirId) select which
        annotated images are counted.

    Prints "<class> : <count>" for every listed class followed by
    "others labels: <n>", where n is the number of instances whose category
    is not listed. Only the per-class dict is returned.
    """
    # Start every known class at zero; the handle is closed by the context
    # manager.
    with open(classesPath) as class_file:
        classesCount = dict.fromkeys(json.load(class_file), 0)

    # A set keeps the per-image membership test O(1) instead of scanning the
    # whole id list once per annotated image.
    present_ids = set(getDirId(imgPath))

    other_instances = 0
    for img_id in tqdm.tqdm(annos["imgs"].keys()):
        if img_id not in present_ids:
            continue
        for obj in annos["imgs"][img_id]["objects"]:
            category = obj['category']
            if category in classesCount:
                classesCount[category] += 1
            else:
                other_instances += 1

    for key, value in classesCount.items():
        print(key, ":", value)
    print('others labels: ' + str(other_instances))
    return classesCount


# def Not45marksCount(annos, classNames):
#     """
#     annos: dict,读取的json格式的标签
#     """
#     annos = annos
#     ids = annos["imgs"].keys()
#     trainIds = getDirId(os.path.join(tt100kPath, "data/otherimgs"))
#     count = 0
#     for id in tqdm.tqdm(ids):
#         if id in trainIds:
#             objects = annos["imgs"][id]["objects"]
#             for object in objects:
#                 if object not in classNames:
#                     count += 1
#     return count

# 判断是否达到扩充条件,若需要扩充则返回Ture
def needAug(classesDict, threshold):
    """Tell whether any class still sits below ``threshold``.

    classesDict: dict mapping a class name to its current instance count.
    threshold: target instance count.

    Returns True when at least one count is below the threshold, so more
    augmentation is needed; returns False otherwise, including for an empty
    dict.
    """
    return any(count < threshold for count in classesDict.values())


# 返回打乱顺序后的字典
def randomDict(classesDict):
    """Return a new dict with the same items in a randomly shuffled key order."""
    shuffled_keys = list(classesDict.keys())
    random.shuffle(shuffled_keys)
    return {name: classesDict[name] for name in shuffled_keys}


# 判断一张图片中需要扩充的实例类别是否占总的个数超80%以上,避免扩充太多其他类别
def needAug3(names, classesDict, threshold=0.8, max_count=500):
    """Decide whether an image is dominated by classes that still need augmenting.

    names: category names of the instances in one image.
    classesDict: dict mapping a class name to its current instance count.
    threshold: minimum fraction (exclusive) of instances that must belong to
        under-represented classes.
    max_count: a class counts as under-represented while its instance count
        is below this value (previously hard-coded to 500).

    Returns True when the fraction of instances from under-represented
    classes is strictly greater than ``threshold``. An image without
    instances returns False instead of raising ZeroDivisionError.
    """
    if not names:
        return False
    scarce = sum(
        1 for name in names
        if name in classesDict and classesDict[name] < max_count
    )
    return (scarce / len(names)) > threshold


# 调整mark亮度
def brightness(img):
    """Return a single brightness value for ``img``.

    The value is the square root of a weighted sum of the squared per-band
    means reported by ``ImageStat.Stat`` (weights 0.241, 0.691 and 0.068 on
    bands 0, 1 and 2). The image therefore needs at least three bands.
    """
    band_means = ImageStat.Stat(img).mean
    weighted_sum = (0.241 * (band_means[0] ** 2)
                    + 0.691 * (band_means[1] ** 2)
                    + 0.068 * (band_means[2] ** 2))
    return math.sqrt(weighted_sum)


# p一个标志
def paste_one_mask(obj_anno, mark, im2):
    """Paste one sign template over an annotated box of ``im2``.

    obj_anno: instance dict; the box is read from obj_anno['bbox'] with keys
        'xmin', 'ymin', 'xmax', 'ymax'.
    mark: template image of the sign. It is also used as its own paste
        mask, so it presumably carries an alpha channel (verify the files).
    im2: background image; it is modified in place and also returned.
    """
    bbox = obj_anno['bbox']
    left = int(bbox['xmin'])
    top = int(bbox['ymin'])
    # Width and height run from the truncated top-left corner to the rounded
    # bottom-right corner.
    target_size = (round(bbox['xmax']) - left, round(bbox['ymax']) - top)
    resized = mark.resize(target_size)
    # Scale the template brightness by the mark/background brightness ratio.
    # NOTE(review): matching the background would normally use the inverse
    # ratio; confirm that the direction is intended.
    enhancer = ImageEnhance.Brightness(resized)
    adjusted = enhancer.enhance(brightness(mark) / brightness(im2))
    im2.paste(adjusted, (left, top), mask=adjusted)
    return im2


# 同样p一个标志,输入的obj_anno格式不同而已
def paste_one_mask2(obj_anno, mark, im2):
    """Paste one sign template into ``im2`` at a box given as a sequence.

    obj_anno: box as an indexable (xmin, ymin, xmax, ymax).
    mark: template image of the sign. It is also used as its own paste
        mask, so it presumably carries an alpha channel (verify the files).
    im2: background image; it is modified in place and also returned.
    """
    left = int(obj_anno[0])
    top = int(obj_anno[1])
    # Width and height run from the truncated top-left corner to the rounded
    # bottom-right corner.
    target_size = (round(obj_anno[2]) - left, round(obj_anno[3]) - top)
    resized = mark.resize(target_size)
    # Scale the template brightness by the mark/background brightness ratio.
    # NOTE(review): matching the background would normally use the inverse
    # ratio; confirm that the direction is intended.
    enhancer = ImageEnhance.Brightness(resized)
    adjusted = enhancer.enhance(brightness(mark) / brightness(im2))
    im2.paste(adjusted, (left, top), mask=adjusted)
    return im2


# 计算粘贴的两个标志重合度,太高则丢弃重新选
def cal_iou(boxA, boxB):
    """Return the share of ``boxB`` that is covered by ``boxA``.

    boxA, boxB: sequences [xmin, ymin, xmax, ymax].

    Despite the name this is not a symmetric intersection-over-union: the
    overlap area is divided by the area of ``boxB`` only. Areas use the
    inclusive "+ 1" pixel convention. Boxes that merely touch or do not
    overlap give 0.0.
    """
    # Corners of the overlap rectangle.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # An empty or degenerate overlap counts as no overlap at all.
    if right <= left or bottom <= top:
        return 0.0

    overlap_area = (right - left + 1) * (bottom - top + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    return overlap_area / float(area_b)


if __name__ == '__main__':
    # Raw string: the plain literal contained the invalid escape sequences
    # '\d' and '\T', which newer Python versions warn about. The resulting
    # path value is unchanged.
    tt100kPath = r'E:\dataset\TT100k'
    annoPath = os.path.join(tt100kPath, 'data', 'annotations.json')
    classesPath = os.path.join(tt100kPath, 'data', 'TT100K_VOC_classes.json')
    marksPath = os.path.join(tt100kPath, 'data', 'marks', 'pad-all')

    # Load the annotations and the class table; the context managers close
    # both handles, which were previously left open.
    with open(annoPath) as anno_file:
        annos = json.load(anno_file)
    with open(classesPath, 'r') as class_file:
        classNames = json.load(class_file).keys()
    # Location of the training images.
    allimgPath = os.path.join(tt100kPath, "data/train/")
    trainIds = getDirId(allimgPath)
    # Every image id present in annotations.json.
    ids = annos["imgs"].keys()

    # Output folders for the labels and images produced by each stage.
    train_labels = os.path.join(tt100kPath, "data/TrainLabels")
    others_labels = os.path.join(tt100kPath, "data/Not45Labels")
    others_labels2 = os.path.join(tt100kPath, "data/Not45Labels_after_replace")
    others_labels3 = os.path.join(tt100kPath, "data/ObjMoreThan5Labels_after_replace")
    others_labels4 = os.path.join(tt100kPath, "data/LabelsAfterPaste")
    others_imgs = os.path.join(tt100kPath, "data/Not45Images")
    others_imgs2 = os.path.join(tt100kPath, "data/Not45Images_after_replace")
    others_imgs3 = os.path.join(tt100kPath, "data/ObjMoreThan5Images_after_replace")
    others_imgs4 = os.path.join(tt100kPath, "data/ImagesAfterPaste")
    for filepath in [train_labels, others_labels, others_labels2, others_labels3, others_labels4, others_imgs,
                     others_imgs2, others_imgs3, others_imgs4]:
        if not os.path.exists(filepath):
            os.makedirs(filepath)
    Not_TT45_list = []
    # Convert the JSON annotations to XML. Images containing at least one
    # instance outside the class list are moved to others_imgs and their
    # labels are written to others_labels; all other labels go to train_labels.
    for id in tqdm.tqdm(ids):
        # Only handle annotated images that are present in the train folder.
        if id in trainIds:
            objects = annos["imgs"][id]["objects"]
            flag = is_tt45(classNames, objects)
            # flag is False when the image has an instance outside the class
            # list: record its id, write its label and move the image away.
            if flag is False:
                Not_TT45_list.append(id + '\n')
                edit_xml(objects, id, dir=others_labels, label45=True)
                imgPath = os.path.join(allimgPath, id + '.jpg')
                shutil.move(imgPath, others_imgs)
            else:
                edit_xml(objects, id, dir=train_labels, label45=True)

    # Record the ids of the moved images in a text file.
    with open(os.path.join(tt100kPath, 'data', "Not_TT45_list_train.txt"), "w") as f:
        f.writelines(Not_TT45_list)
    # After the move, the train folder only holds images whose instances all
    # belong to the class list. Count the instances per class in what is left.
    print('只包含45类的图片各类别统计:')
    # Per-class counts for the remaining train images (also printed).
    classesCount = marksCount(annos, classesPath, allimgPath)
    print('========================================')
    print('包含其他类别的图片各类别(45类)统计:')
    # Per-class counts for the moved images; the number of unlisted-class
    # instances is printed by the helper.
    othersCount = Not45marksCount(annos, classesPath, others_imgs)
    # Total per class = classesCount + othersCount.
    print('原始训练集中所有的45类实例数统计:')
    allCount = {}
    for key, value in classesCount.items():
        allCount[key] = classesCount[key] + othersCount[key]
    for key, value in allCount.items():
        print(key, ":", value)
    print('========================================')
    # print('Not45marksCount: ' + Not45marksCount(annos, classesPath, others_imgs))
    # ===== Split the train images by instance count and write the id lists to txt =====
    # Refresh the train id list now that some images have been moved out.
    trainIds = getDirId(allimgPath)
    ob_less_than1 = []
    ob_more_than5 = []
    ob_between_1_and_5 = []
    for id in tqdm.tqdm(ids):
        # Only handle annotated images that are present in the train folder.
        if id in trainIds:
            objects = annos["imgs"][id]["objects"]
            # Buckets: at most 1 instance, 5 or more instances, and 2 to 4.
            if len(objects) <= 1:
                ob_less_than1.append(id)
            elif len(objects) >= 5:
                ob_more_than5.append(id)
            elif 1 < len(objects) < 5:
                ob_between_1_and_5.append(id)
    print("实例数大于5的图片数量:{};实例数小于等于1的图片数量:{};实例数在1到5之间的图片数量:{}".format(len(ob_more_than5),
                                                                     len(ob_less_than1), len(ob_between_1_and_5)))
    with open(os.path.join(tt100kPath, 'data', "ob_less_than1.txt"), "w") as f:
        for id in ob_less_than1:
            f.write(id + '\n')
    with open(os.path.join(tt100kPath, 'data', "ob_more_than5.txt"), "w") as f:
        for id in ob_more_than5:
            f.write(id + '\n')
    with open(os.path.join(tt100kPath, 'data', "ob_between_1_and_5.txt"), "w") as f:
        for id in ob_between_1_and_5:
            f.write(id + '\n')
    # ================================================================================================ #
    # Select the classes with at most 200 instances and load their sign
    # template images; these are the classes the replacement passes top up.
    lessThan200 = {}
    lessThan400 = {}
    lessThan600 = {}
    marksFor21 = {}
    marksFor28 = {}
    for key, value in allCount.items():
        if allCount[key] <= 200:
            marksFor21[key] = Image.open(os.path.join(marksPath, key + '.png'))
            lessThan200[key] = allCount[key]
    sampleImageIds = getDirId(others_imgs)

    # First replacement pass: take the images that were moved to others_imgs
    # and overwrite their unlisted-class signs with templates of rare classes,
    # until every rare class reaches 145 instances.
    for id in tqdm.tqdm(sampleImageIds):
        # Stop early once no class is below the target.
        if not needAug(lessThan200, 145):
            print('增强完成')
            break
        sampleImg = Image.open(os.path.join(others_imgs, id + '.jpg'))
        # objs aliases the entry inside annos, so the category edits below
        # also change annos.
        objs = annos["imgs"][id]["objects"]
        # Classes already pasted onto this image.
        labels = []
        # For each instance, pick a rare class and paste its template on top.
        for obj in objs:
            # flag records whether a class with the same first character was found.
            flag = False
            objName = obj['category']
            # Instances that already belong to the class list are left alone.
            if objName in classNames:
                continue
            # Look for a class that is still below 145 instances, has not been
            # used on this image, and shares the first character of the
            # original category name (presumably the sign family - verify).
            # Shuffle the dict first so the choice is random.
            lessThan200 = randomDict(lessThan200)
            for key, value in lessThan200.items():
                if (lessThan200[key] < 145) and (key not in labels) and (key[0] == objName[0]):
                    flag = True
                    # Relabel the instance with the chosen class.
                    newName = key
                    obj['category'] = newName
                    labels.append(key)
                    # paste_one_mask draws on sampleImg in place; the returned
                    # reference stored in sampleImg1 is not used afterwards.
                    sampleImg1 = paste_one_mask(obj, marksFor21[key], sampleImg)
                    # Update the instance count.
                    lessThan200[key] += 1
                    # One instance replaced; move on to the next instance.
                    break
            # No matching first character: fall back to any unused class that
            # is still below the target.
            if flag == False:
                # Shuffle the dict first so the choice is random.
                lessThan200 = randomDict(lessThan200)
                for key, value in lessThan200.items():
                    if (lessThan200[key] < 145) and (key not in labels):
                        # Relabel the instance with the chosen class.
                        newName = key
                        obj['category'] = newName
                        labels.append(key)
                        # Paste the template onto sampleImg in place.
                        sampleImg1 = paste_one_mask(obj, marksFor21[key], sampleImg)
                        # Update the instance count.
                        lessThan200[key] += 1
                        # One instance replaced; move on to the next instance.
                        break

        # Write the new label and image. Instances that could not be replaced
        # keep their unlisted category and are dropped from the XML by
        # label45=True, although they remain visible in the image.
        edit_xml(objs, id, dir=others_labels2, label45=True)
        sampleImg.save(os.path.join(others_imgs2, id + '.jpg'))

    # Copy the updated counts back into allCount and print them.
    print('========================================')
    print('第一次全部替换完毕,更新此时的实例数:')
    for key, value in lessThan200.items():
        allCount[key] = lessThan200[key]
        print(key + ':' + str(lessThan200[key]))

    SecondReplaceIds = []
    # Second replacement pass: take the train images with 5 or more instances
    # and overwrite their signs with templates of rare classes, until every
    # rare class reaches 200 instances.
    for id in tqdm.tqdm(ob_more_than5):
        # Stop early once no class is below the target.
        if not needAug(lessThan200, 200):
            print('增强完成')
            break
        sampleImg = Image.open(os.path.join(allimgPath, id + '.jpg'))
        # objs aliases the entry inside annos, so the category edits below
        # also change annos.
        objs = annos["imgs"][id]["objects"]
        # Classes already pasted onto this image.
        labels = []
        # For each instance, pick a rare class and paste its template on top.
        for obj in objs:
            # flag records whether a class with the same first character was found.
            flag = False
            objName = obj['category']
            # An instance that already belongs to a rare class is kept as is,
            # but it is counted because this image is saved as an extra copy.
            if objName in lessThan200.keys():
                lessThan200[objName] += 1
                continue
            # Look for a class that is still below 200 instances, has not been
            # used on this image, and shares the first character of the
            # original category name (presumably the sign family - verify).
            # Shuffle the dict first so the choice is random.
            lessThan200 = randomDict(lessThan200)
            for key, value in lessThan200.items():
                if (lessThan200[key] < 200) and (key not in labels) and (key[0] == objName[0]):
                    flag = True
                    # Relabel the instance with the chosen class.
                    newName = key
                    obj['category'] = newName
                    labels.append(key)
                    # paste_one_mask draws on sampleImg in place; the returned
                    # reference stored in sampleImg1 is not used afterwards.
                    sampleImg1 = paste_one_mask(obj, marksFor21[key], sampleImg)
                    # Update the instance count.
                    lessThan200[key] += 1
                    # One instance replaced; move on to the next instance.
                    break
            # No matching first character: fall back to any unused class that
            # is still below the target.
            if flag == False:
                # Shuffle the dict first so the choice is random.
                lessThan200 = randomDict(lessThan200)
                for key, value in lessThan200.items():
                    if (lessThan200[key] < 200) and (key not in labels):
                        # Relabel the instance with the chosen class.
                        newName = key
                        obj['category'] = newName
                        labels.append(key)
                        # Paste the template onto sampleImg in place.
                        sampleImg1 = paste_one_mask(obj, marksFor21[key], sampleImg)
                        # Update the instance count.
                        lessThan200[key] += 1
                        # One instance replaced; move on to the next instance.
                        break
        # Apply an image transform to the copy; only light change, noise and
        # Gaussian blur have a non-zero rate here.
        dataAug = DataAugmentForObjectDetection(rotation_rate=0,
                                                max_rotation_angle=5,
                                                crop_rate=0,
                                                shift_rate=0,
                                                change_light_rate=0.5,
                                                add_noise_rate=0.5,
                                                gaussianblur_rate=0.5,
                                                flip_rate=0,
                                                cutout_rate=0,
                                                cut_out_length=100,
                                                cut_out_holes=2,
                                                cut_out_threshold=0.01)
        # dataAugment returns the augmented image and the box information.
        # NOTE(review): coords are parsed from the XML written to train_labels
        # before any replacement, so the categories changed in objs above do
        # not appear to reach the new label file - confirm against parse_xml
        # and generate_xml.
        xml_path = os.path.join(train_labels, id + '.xml')
        coords, shape = parse_xml(xml_path)
        # The PIL image is converted from RGB to BGR channel order for cv2.
        auged_img, auged_objects = dataAug.dataAugment(cv2.cvtColor(np.asarray(sampleImg), cv2.COLOR_RGB2BGR),
                                                       coords)
        # Write the new label and image under a suffixed id.
        new_id = id + '_repalce'
        generate_xml(img_name=new_id + '.jpg', coords=auged_objects, img_size=auged_img.shape,
                     out_root_path=others_labels3)
        cv2.imwrite(os.path.join(others_imgs3, new_id + '.jpg'), auged_img)
        # Remember which source image was copied in this pass.
        SecondReplaceIds.append(id)

    # Copy the updated counts back into allCount and print them.
    print('========================================')
    print('第二次替换完毕,更新此时的实例数:')
    for key, value in lessThan200.items():
        allCount[key] = lessThan200[key]
        print(key + ':' + str(lessThan200[key]))
    with open(os.path.join(tt100kPath, 'data', 'SecondReplaceIds.txt'), 'w') as sf:
        for id in SecondReplaceIds:
            sf.write(id + '\n')
    # ============================ Paste-based augmentation ======================================
    print('开始贴图')
    # Select the classes with 200 to 400 instances and load their templates.
    for key, value in allCount.items():
        if (allCount[key] >= 200) and allCount[key] <= 400:
            lessThan400[key] = allCount[key]
            marksFor28[key] = Image.open(os.path.join(marksPath, key + '.png'))
    # Use the train images that have at most one instance as backgrounds.
    for id in tqdm.tqdm(ob_less_than1):
        # Stop early once no class is below the target of 400.
        if not needAug(lessThan400, 400):
            print('黏贴完成')
            break
        sampleImg = Image.open(os.path.join(allimgPath, id + '.jpg'))
        # xml_path = os.path.join(train_labels, id + '.xml')
        # coords, shape = parse_xml(xml_path)
        # objects holds the instances already in the image; pasted signs must
        # not overlap them. It aliases the entry inside annos, so the appends
        # below also change annos.
        objects = annos["imgs"][id]["objects"]
        # Classes already pasted onto this image.
        labels = []
        # Paste up to 5 signs per image.
        for n in range(5):
            # flag records whether a class not yet used on this image was found.
            flag = False
            # Stop early once no class is below the target.
            if not needAug(lessThan400, 400):
                print('黏贴完成')
                break
            # Shuffle the dict first so the choice is random.
            lessThan400 = randomDict(lessThan400)
            for key, value in lessThan400.items():
                # Pick a class below 400 that has not been used on this image.
                if (lessThan400[key] < 400) and (key not in labels):
                    flag = True
                    mark = marksFor28[key]
                    objName = key
                    # Side length of the square box, between 20 and 119 pixels.
                    size = int(20 + random.random() * 100)
                    # NOTE(review): the result of resize is discarded, so this
                    # line has no effect; the actual resize happens inside
                    # paste_one_mask2.
                    mark.resize((size, size))
                    # Draw positions until one barely overlaps existing boxes.
                    # NOTE(review): there is no retry limit on this loop.
                    while(True):
                        # Random top-left corner: x in [0, 2048 - size],
                        # y in [512, 1536 - size].
                        x1, y1 = random.randint(0, int(2048 - size)), random.randint(512, int(1536 - size))
                        x2, y2 = int(x1 + size), int(y1 + size)
                        # Sum the overlap ratios against every existing box and
                        # accept the position when the sum is below 0.05.
                        ious = []
                        iousum = 0.0
                        for obj in objects:
                            box = [obj['bbox']['xmin'], obj['bbox']['ymin'], obj['bbox']['xmax'], obj['bbox']['ymax']]
                            ious.append(cal_iou([x1, y1, x2, y2], box))
                        for iou in ious:
                            iousum = iousum + iou
                        if iousum < 0.05:
                            break
                    paste_one_mask2((x1, y1, x2, y2), mark, sampleImg)
                    labels.append(objName)
                    new_dict = {'category': objName, 'bbox': {'xmin': x1, 'ymin': y1, 'xmax': x2, 'ymax': y2}}
                    # new_dict['bbox']['xmin'] = x1
                    # new_dict['bbox']['ymin'] = y1
                    # new_dict['bbox']['xmax'] = x2
                    # new_dict['bbox']['ymax'] = y2
                    objects.append(new_dict)
                    # One sign pasted: update the count and pick the next one.
                    lessThan400[objName] += 1
                    break
            # Every remaining class was already used on this image: allow
            # repeats and paste any class that is still below the target.
            if flag == False:
                # Shuffle the dict first so the choice is random.
                lessThan400 = randomDict(lessThan400)
                for key, value in lessThan400.items():
                    if (lessThan400[key] < 400):
                        mark = marksFor28[key]
                        objName = key
                        # Side length of the square box, between 20 and 119 pixels.
                        size = int(20 + random.random() * 100)
                        # NOTE(review): the result of resize is discarded, so
                        # this line has no effect; the actual resize happens
                        # inside paste_one_mask2.
                        mark.resize((size, size))
                        # Draw positions until one barely overlaps existing
                        # boxes. NOTE(review): there is no retry limit here.
                        while (True):
                            # Random top-left corner: x in [0, 2048 - size],
                            # y in [512, 1536 - size], which keeps the sign
                            # inside the image.
                            x1, y1 = random.randint(0, int(2048 - size)), random.randint(512, int(1536 - size))
                            x2, y2 = int(x1 + size), int(y1 + size)
                            # Sum the overlap ratios against every existing box
                            # and accept the position when the sum is below 0.05.
                            ious = []
                            iousum = 0.0
                            for obj in objects:
                                box = [obj['bbox']['xmin'], obj['bbox']['ymin'], obj['bbox']['xmax'], obj['bbox']['ymax']]
                                ious.append(cal_iou([x1, y1, x2, y2], box))
                            for iou in ious:
                                iousum = iousum + iou
                            if iousum < 0.05:
                                break
                        paste_one_mask2([x1, y1, x2, y2], mark, sampleImg)
                        labels.append(objName)
                        new_dict = {'category': objName, 'bbox': {'xmin': x1, 'ymin': y1, 'xmax': x2, 'ymax': y2}}
                        # new_dict['category'] = objName
                        # new_dict['bbox']['xmin'] = x1
                        # new_dict['bbox']['ymin'] = y1
                        # new_dict['bbox']['xmax'] = x2
                        # new_dict['bbox']['ymax'] = y2
                        objects.append(new_dict)
                        # One sign pasted: update the count and pick the next one.
                        lessThan400[objName] += 1
                        break
        # Write the new label and image.
        edit_xml(objects, id, dir=others_labels4, label45=True)
        # generate_xml(img_name=id + '.jpg', coords=objects, img_size=[2048, 2048, 3],
        #              out_root_path=others_labels4)
        sampleImg.save(os.path.join(others_imgs4, id + '.jpg'))

    # Copy the updated counts back into allCount and print all of them.
    print('========================================')
    print('全部贴图完毕,更新此时的实例数:')
    for key, value in lessThan400.items():
        allCount[key] = lessThan400[key]
    for key, value in allCount.items():
        print(key + ':' + str(allCount[key]))
    json_data = json.dumps(allCount)
    # Save the final statistics as JSON.
    with open(os.path.join(tt100kPath, 'data/allcount.json'), 'w') as f_six:
        f_six.write(json_data)













 使用替换法后效果如下:

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值