数据处理之数据增强

  数据增强的由来

    在训练模型的时候,我们构建好了网络,只差数据了,这时搜集数据就很重要。当样本数据太少,我们怎么办? 

    这时数据增强就派上用处啦!!!

数据增强分类

线上增强和线下增强

    线上增强:在读取图片,导入网络的时候进行随机的数据增强(比如图片的翻转、旋转....)

    线下增强:将收集到的数据先进行处理,比如一张图片经过翻转、折叠、调整亮度、旋转等,可以增加多张相似的图片

    本人现在主要用线下增强的方式(主要是线上增强还熟悉,不好意思)

数据增强的作用

     数据增强可以有效解决过拟合问题(增加了数据量,当然学习到的特征更多),增强模型的泛化能力。

代码来啦!!!

这是已经打标后的数据进行数据增强的方式(已经有img和xml格式的数据啦)

需要修改的地方:

1.原始的img与xml文件路径

IMG_DIR = "./dataset/JPEGImages"
       XML_DIR = "./dataset/Annotations"

2.存储增强后的img和xml文件夹路径
        AUG_XML_DIR = "./Annotations"

AUG_IMG_DIR = "./JPEGImages"  

3.每张影像增强的数量
       AUGLOOP = 10

4.iaa.Sequential:需要更多的数据增强效果的,可以去查询库

5.这是png图片的处理,若是jpg,请把代码中的所有png改成jpg即可

import xml.etree.ElementTree as ET
import pickle
import os
from os import getcwd
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as plt
import imgaug as ia
from imgaug import augmenters as iaa


ia.seed(1)


def read_xml_annotation(root, image_id):
    in_file = open(os.path.join(root, image_id))
    tree = ET.parse(in_file)
    root = tree.getroot()
    bndboxlist = []

    for object in root.findall('object'):  # 找到root节点下的所有country节点
        bndbox = object.find('bndbox')  # 子节点下节点rank的值

        xmin = int(bndbox.find('xmin').text)
        xmax = int(bndbox.find('xmax').text)
        ymin = int(bndbox.find('ymin').text)
        ymax = int(bndbox.find('ymax').text)
        # print(xmin,ymin,xmax,ymax)
        bndboxlist.append([xmin, ymin, xmax, ymax])
        # print(bndboxlist)

    bndbox = root.find('object').find('bndbox')
    return bndboxlist


# (506.0000, 330.0000, 528.0000, 348.0000) -> (520.4747, 381.5080, 540.5596, 398.6603)
def change_xml_annotation(root, image_id, new_target):
    new_xmin = new_target[0]
    new_ymin = new_target[1]
    new_xmax = new_target[2]
    new_ymax = new_target[3]

    in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
    tree = ET.parse(in_file)
    xmlroot = tree.getroot()
    object = xmlroot.find('object')
    bndbox = object.find('bndbox')
    xmin = bndbox.find('xmin')
    xmin.text = str(new_xmin)
    ymin = bndbox.find('ymin')
    ymin.text = str(new_ymin)
    xmax = bndbox.find('xmax')
    xmax.text = str(new_xmax)
    ymax = bndbox.find('ymax')
    ymax.text = str(new_ymax)
    tree.write(os.path.join(root, str("%06s" % (str(id) + '.xml'))))


def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
    tree = ET.parse(in_file)
    elem = tree.find('filename')
    id = str(id)
    elem.text = str("%06s" % id) + '.png'
    xmlroot = tree.getroot()
    index = 0

    for object in xmlroot.findall('object'):  # 找到root节点下的所有country节点
        bndbox = object.find('bndbox')  # 子节点下节点rank的值

        # xmin = int(bndbox.find('xmin').text)
        # xmax = int(bndbox.find('xmax').text)
        # ymin = int(bndbox.find('ymin').text)
        # ymax = int(bndbox.find('ymax').text)

        new_xmin = new_target[index][0]
        new_ymin = new_target[index][1]
        new_xmax = new_target[index][2]
        new_ymax = new_target[index][3]

        xmin = bndbox.find('xmin')
        xmin.text = str(new_xmin)
        ymin = bndbox.find('ymin')
        ymin.text = str(new_ymin)
        xmax = bndbox.find('xmax')
        xmax.text = str(new_xmax)
        ymax = bndbox.find('ymax')
        ymax.text = str(new_ymax)

        index = index + 1

    tree.write(os.path.join(saveroot, str("%06s" %id) + '.xml'))


def mkdir(path):
    # 去除首位空格
    path = path.strip()
    # 去除尾部 \ 符号
    path = path.rstrip("\\")
    # 判断路径是否存在
    # 存在     True
    # 不存在   False
    isExists = os.path.exists(path)
    # 判断结果
    if not isExists:
        # 如果不存在则创建目录
        # 创建目录操作函数
        os.makedirs(path)
        print(path + ' 创建成功')
        return True
    else:
        # 如果目录存在则不创建,并提示目录已存在
        print(path + ' 目录已存在')
        return False


if __name__ == "__main__":
    # 原始的img与xml文件路径
    IMG_DIR = "./dataset/JPEGImages"
    XML_DIR = "./dataset/Annotations"
    # 存储增强后的XML文件夹路径
    AUG_XML_DIR = "./Annotations"
    try:
        shutil.rmtree(AUG_XML_DIR)
    except FileNotFoundError as e:
        a = 1
    mkdir(AUG_XML_DIR)
    # 存储增强后的影像文件夹路径
    AUG_IMG_DIR = "./JPEGImages"  
    try:
        shutil.rmtree(AUG_IMG_DIR)
    except FileNotFoundError as e:
        a = 1
    mkdir(AUG_IMG_DIR)
    # 每张影像增强的数量
    AUGLOOP = 10

    boxes_img_aug_list = []
    new_bndbox = []
    new_bndbox_list = []

    # 影像增强
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),  # 对50%的图像做镜像翻转
        iaa.ContrastNormalization((0.75,1.5),per_channel=True),
        iaa.Crop(percent=(0, 0.1),keep_size=True),
        iaa.Multiply((1.2, 1.5)),  # 改变亮度
        iaa.GaussianBlur(sigma=(0, 3.0)),  # iaa.GaussianBlur(0.5),
        iaa.Affine(
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
            rotate=(-30, 30)
        )  # 对一部分图像做仿射变换, rotate旋转±30度之间, scale图像缩放为80%到95%之间, translate_px 独立地在x轴和y轴上将图像平移到15像素
    ])
    # os.walk() 方法用于通过在目录树中游走输出在目录中的文件名,向上或者向下
    # root所指的是当前正在遍历的这个文件夹的本身的地址
    # sub_folders 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
    # files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
    for root, sub_folders, files in os.walk(XML_DIR):
        for name in files:
            bndbox = read_xml_annotation(XML_DIR, name)
            shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
            shutil.copy(os.path.join(IMG_DIR, name[:-4] + '.png'), AUG_IMG_DIR)

            for epoch in range(AUGLOOP):
                seq_det = seq.to_deterministic()  # 保持坐标和图像同步改变,而不是随机
                # 读取图片
                img = Image.open(os.path.join(IMG_DIR, name[:-4] + '.png'))
                # sp = img.size
                img = np.asarray(img)
                # bndbox 坐标增强
                for i in range(len(bndbox)):
                    bbs = ia.BoundingBoxesOnImage([
                        ia.BoundingBox(x1=bndbox[i][0], y1=bndbox[i][1], x2=bndbox[i][2], y2=bndbox[i][3]),
                    ], shape=img.shape)

                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]
                    boxes_img_aug_list.append(bbs_aug)

                    # new_bndbox_list:[[x1,y1,x2,y2],...[],[]]
                    n_x1 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x1)))
                    n_y1 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y1)))
                    n_x2 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x2)))
                    n_y2 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y2)))
                    if n_x1 == 1 and n_x1 == n_x2:
                        n_x2 += 1
                    if n_y1 == 1 and n_y2 == n_y1:
                        n_y2 += 1
                    if n_x1 >= n_x2 or n_y1 >= n_y2:
                        print('error', name)
                    new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])
                # 存储变化后的图片
                image_aug = seq_det.augment_images([img])[0]
                path = os.path.join(AUG_IMG_DIR, str(len(files)) + str(name[:-4]) + str(epoch * 250) + '.png')
                image_auged = bbs.draw_on_image(image_aug, thickness=0)
                Image.fromarray(image_auged).convert('RGB').save(path)

                # 存储变化后的XML
                change_xml_list_annotation(XML_DIR, name[:-4], new_bndbox_list, AUG_XML_DIR,
                                           str(len(files)) + str(name[:-4]) + str(epoch * 250))
                print(str(len(files)) + str(name[:-4]) + str(epoch * 250) + '.png')
                new_bndbox_list = []

想要具体了解增强方式,可以参考博文数据增强(Data Augmentation) - 知乎

结尾:代码我只修改数据增强的部分,源码真的不记得哪里来的了,原作者看到可以私聊我,我改成转载。

  • 2
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

呆呆珝

您的打赏是我的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值