一个有效的小目标检测的数据增强方法Mixup及其变体填鸭式

系列文章目录


前言

我们知道目标检测数据集中数据和标签需要一一对应,一旦对图像数据做了增强处理后(目标bbox发生改变),标签也需要做相应的修改。
比较work的数据增强方法:
Mosaic
MixUp
Resize
LetterBox
RandomCrop
RandomFlip
RandomHSV
RandomBlur
RandomNoise
RandomAffine
RandomTranslation
Normalize
ImageToTensor

这些都是很容易就能嵌入到我们的训练框架中,下面介绍一种比较有用的方法,对小目标和目标背景缺乏的场景下涨点明显。也是一种解决样本少,不均衡的方法。

一、增强效果

GitHubDetection_Augmentation
如下图,图片只含有一个目标,我们可以将另外一张图里的目标扣下来,贴到这一张图上去,熟悉目标检测的都知道当我们训练业务场景的数据时,这样增强对模型的泛化能力的提升是很积极的, 下面具体讲解如何实现。
label:1 0.5751 0.3541666666666667 0.28125 0.38
在这里插入图片描述
下图是将目标贴到原图里,并且label文件也保持下来了,第二行开始为新增的三个目标。
1 0.571875 0.35 0.28125 0.38
1 0.2 0.15 0.1875 0.25
2 0.875 0.4666666666666667 0.1875 0.25
3 0.51875 0.775 0.1875 0.25

在这里插入图片描述

二、方法讲解

此方法是在yolo标签格式下完成的,如果你们的数据标签是VOC或coco格式,需先转换成yolo格式,增强之后在转回来。

1. 原图数据

待增强的图像和标签文件
在这里插入图片描述

2. 截取目标roi

python crop_image.py   # 根据bbox截取目标roi,并保存图片

在这里插入图片描述

3. 运行demo.py

import os
import random
from os.path import join
import aug
import Helpers as hp
from util import *

# ###########Pipeline##############
"""
1 准备数据集和yolo格式标签, 如果自己的数据集是voc或coco格式的,先转换成yolo格式,增强后在转回来
2 run crop_image.py  裁剪出目标并保存图片
3 run demo.py   随机将裁剪出目标图片贴到需要增强的数据集上,并且保存增强后的图片集和label文件
"""

base_dir = os.getcwd()

save_base_dir = join(base_dir, 'save_path')

check_dir(save_base_dir)

# imgs_dir = [f.strip() for f in open(join(base_dir, 'sea.txt')).readlines()]
imgs_dir = [os.path.join('fruit', f) for f in os.listdir('fruit') if f.endswith('jpg')]
labels_dir = hp.replace_labels(imgs_dir)
# print(imgs_dir, '\n', labels_dir)

# small_imgs_dir = [f.strip() for f in open(join(base_dir, 'dpj_small.txt')).readlines()]
small_imgs_dir = [os.path.join('fruit_image', f) for f in os.listdir('fruit_image') if f.endswith('jpg')]
random.shuffle(small_imgs_dir)  # 目标图片打乱
# print(small_imgs_dir)

times = 3  # 随机选择增加多少个目标

for image_dir, label_dir in zip(imgs_dir, labels_dir):
    # print(image_dir, label_dir)
    small_img = []
    for x in range(times):
        if small_imgs_dir == []:
            small_imgs_dir = [os.path.join('fruit_image', f) for f in os.listdir('fruit_image') if f.endswith('jpg')]
            random.shuffle(small_imgs_dir)
        small_img.append(small_imgs_dir.pop())
    # print("ok")
    aug.copysmallobjects(image_dir, label_dir, save_base_dir, small_img, times)

aug.py

 new_bboxes = random_add_patches(roi.shape,     # 此函数roi目标贴到原图像上,返回的bbox为roi在原图上的bbox,
                               rescale_labels,  # 并且bbox不会挡住图片上原有的目标
                               image.shape,
                               paste_number=1,  # 将该roi目标复制几次并贴到到原图上
                               iou_thresh=0)    # iou_thresh 原图上的bbox和贴上去的roi的bbox的阈值

当paste_number=1时是第二幅图的结果,当paste_number=2时每个roi目标会复制两张,随机贴在原图上,iou_thresh可以设置目标之间的交并比,,如下图;
在这里插入图片描述

Mixup

此为博客取图,仅作效果展示,运行以下代码可生成下图和对应的label文件
在这里插入图片描述
在这里插入图片描述

import cv2
import os
import random
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom

img_path = 'image/'           # 原始图片文件夹路径
save_path = 'mixup/'       # mixup的图片文件夹路径
xml_path = 'xml/'           # 原始图片对应的标注文件xml文件夹的路径
save_xml = 'mixup_xml/'        # mixup的图片对应的标注文件xml的文件夹路径
img_names = os.listdir(img_path)
img_num = len(img_names)
print('img_num:', img_num)

for imgname in img_names:
    imgpath = img_path + imgname
    if not imgpath.endswith('jpg'):
        continue
    img = cv2.imread(imgpath)
    img_h, img_w = img.shape[0], img.shape[1]
    print(img_h,img_w)

    i = random.randint(0, img_num - 1)
    print('i:', i)
    add_path = img_path + img_names[i]
    addimg = cv2.imread(add_path)
    add_h, add_w = addimg.shape[0], addimg.shape[1]
    if add_h != img_h or add_w != img_w:
        print('resize!')
        addimg = cv2.resize(addimg, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
    scale_h, scale_w = img_h / add_h, img_w / add_w

    lam = np.random.beta(1.5, 1.5)
    print(lam)
    mixed_img = lam * img + (1 - lam) * addimg
    save_img = save_path + imgname[:-4] + '_3.jpg'
    cv2.imwrite(save_img, mixed_img)
    print(save_img)

    print(imgname, img_names[i])
    if imgname != img_names[i]:
        xmlfile1 = xml_path + imgname[:-4] + '.xml'
        xmlfile2 = xml_path + img_names[i][:-4] + '.xml'
        print(xmlfile1,xmlfile2)

        tree1 = ET.parse(xmlfile1)
        tree2 = ET.parse(xmlfile2)

        doc = xml.dom.minidom.Document()
        root = doc.createElement("annotation")
        doc.appendChild(root)


        for folds in tree1.findall("folder"):
            folder = doc.createElement("folder")
            folder.appendChild(doc.createTextNode(str(folds.text)))
            root.appendChild(folder)
        for filenames in tree1.findall("filename"):
            filename = doc.createElement("filename")
            filename.appendChild(doc.createTextNode(str(filenames.text)))
            root.appendChild(filename)
        for paths in tree1.findall("path"):
            path = doc.createElement("path")
            path.appendChild(doc.createTextNode(str(paths.text)))
            root.appendChild(path)
        for sources in tree1.findall("source"):
            source = doc.createElement("source")
            database = doc.createElement("database")
            database.appendChild(doc.createTextNode(str("Unknow")))
            source.appendChild(database)
            root.appendChild(source)
        for sizes in tree1.findall("size"):
            size = doc.createElement("size")
            width = doc.createElement("width")
            height = doc.createElement("height")
            depth = doc.createElement("depth")
            width.appendChild(doc.createTextNode(str(img_w)))
            height.appendChild(doc.createTextNode(str(img_h)))
            depth.appendChild(doc.createTextNode(str(3)))
            size.appendChild(width)
            size.appendChild(height)
            size.appendChild(depth)
            root.appendChild(size)

        nodeframe = doc.createElement("frame")
        nodeframe.appendChild(doc.createTextNode(imgname[:-4] + '_3'))

        objects = []

        for obj in tree1.findall("object"):
            obj_struct = {}
            obj_struct["name"] = obj.find("name").text
            obj_struct["pose"] = obj.find("pose").text
            obj_struct["truncated"] = obj.find("truncated").text
            obj_struct["difficult"] = obj.find("difficult").text
            bbox = obj.find("bndbox")
            obj_struct["bbox"] = [int(bbox.find("xmin").text),
                                  int(bbox.find("ymin").text),
                                  int(bbox.find("xmax").text),
                                  int(bbox.find("ymax").text)]
            objects.append(obj_struct)

        for obj in tree2.findall("object"):
            obj_struct = {}
            obj_struct["name"] = obj.find("name").text
            obj_struct["pose"] = obj.find("pose").text
            obj_struct["truncated"] = obj.find("truncated").text
            obj_struct["difficult"] = obj.find("difficult").text          # 有的版本的labelImg改参数为小写difficult
            bbox = obj.find("bndbox")
            obj_struct["bbox"] = [int(int(bbox.find("xmin").text) * scale_w),
                                  int(int(bbox.find("ymin").text) * scale_h),
                                  int(int(bbox.find("xmax").text) * scale_w),
                                  int(int(bbox.find("ymax").text) * scale_h)]
            objects.append(obj_struct)

        for obj in objects:
            nodeobject = doc.createElement("object")
            nodename = doc.createElement("name")
            nodepose = doc.createElement("pose")
            nodetruncated = doc.createElement("truncated")
            nodedifficult = doc.createElement("difficult")
            nodebndbox = doc.createElement("bndbox")
            nodexmin = doc.createElement("xmin")
            nodeymin = doc.createElement("ymin")
            nodexmax = doc.createElement("xmax")
            nodeymax = doc.createElement("ymax")
            nodename.appendChild(doc.createTextNode(obj["name"]))
            nodepose.appendChild(doc.createTextNode(obj["pose"]))
            nodepose.appendChild(doc.createTextNode(obj["truncated"]))
            nodedifficult.appendChild(doc.createTextNode(obj["difficult"]))
            nodexmin.appendChild(doc.createTextNode(str(obj["bbox"][0])))
            nodeymin.appendChild(doc.createTextNode(str(obj["bbox"][1])))
            nodexmax.appendChild(doc.createTextNode(str(obj["bbox"][2])))
            nodeymax.appendChild(doc.createTextNode(str(obj["bbox"][3])))

            nodebndbox.appendChild(nodexmin)
            nodebndbox.appendChild(nodeymin)
            nodebndbox.appendChild(nodexmax)
            nodebndbox.appendChild(nodeymax)

            nodeobject.appendChild(nodename)
            nodeobject.appendChild(nodepose)
            nodeobject.appendChild(nodetruncated)
            nodeobject.appendChild(nodedifficult)
            nodeobject.appendChild(nodebndbox)

            root.appendChild(nodeobject)

        fp = open(save_xml + imgname[:-4] + "_3.xml", "w")
        doc.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding="utf-8")
        fp.close()

    else:
        xmlfile1 = xml_path + imgname[:-4] + '.xml'
        print(xmlfile1)
        tree1 = ET.parse(xmlfile1)

        doc = xml.dom.minidom.Document()
        root = doc.createElement("annotation")


        doc.appendChild(root)

        for folds in tree1.findall("folder"):
            folder=doc.createElement("folder")
            folder.appendChild(doc.createTextNode(str(folds.text)))
            root.appendChild(folder)
        for filenames in tree1.findall("filename"):
            filename=doc.createElement("filename")
            filename.appendChild(doc.createTextNode(str(filenames.text)))
            root.appendChild(filename)
        for paths in tree1.findall("path"):
            path = doc.createElement("path")
            path.appendChild(doc.createTextNode(str(paths.text)))
            root.appendChild(path)
        for sources in tree1.findall("source"):
            source = doc.createElement("source")
            database = doc.createElement("database")
            database.appendChild(doc.createTextNode(str("Unknow")))
            source.appendChild(database)
            root.appendChild(source)
        for sizes in tree1.findall("size"):
            size = doc.createElement("size")
            width = doc.createElement("width")
            height = doc.createElement("height")
            depth = doc.createElement("depth")
            width.appendChild(doc.createTextNode(str(img_w)))
            height.appendChild(doc.createTextNode(str(img_h)))
            depth.appendChild(doc.createTextNode(str(3)))
            size.appendChild(width)
            size.appendChild(height)
            size.appendChild(depth)
            root.appendChild(size)


        nodeframe = doc.createElement("frame")
        nodeframe.appendChild(doc.createTextNode(imgname[:-4] + '_3'))
        objects = []

        for obj in tree1.findall("object"):
            obj_struct = {}
            obj_struct["name"] = obj.find("name").text
            obj_struct["pose"] = obj.find("pose").text
            obj_struct["truncated"] = obj.find("truncated").text
            obj_struct["difficult"] = obj.find("difficult").text
            bbox = obj.find("bndbox")
            obj_struct["bbox"] = [int(bbox.find("xmin").text),
                                  int(bbox.find("ymin").text),
                                  int(bbox.find("xmax").text),
                                  int(bbox.find("ymax").text)]
            objects.append(obj_struct)

        for obj in objects:
            nodeobject = doc.createElement("object")
            nodename = doc.createElement("name")
            nodepose = doc.createElement("pose")
            nodetruncated = doc.createElement("truncated")
            nodedifficult = doc.createElement("difficult")
            nodebndbox = doc.createElement("bndbox")
            nodexmin = doc.createElement("xmin")
            nodeymin = doc.createElement("ymin")
            nodexmax = doc.createElement("xmax")
            nodeymax = doc.createElement("ymax")
            nodename.appendChild(doc.createTextNode(obj["name"]))
            nodepose.appendChild(doc.createTextNode(obj["pose"]))
            nodetruncated.appendChild(doc.createTextNode(obj["truncated"]))
            nodedifficult.appendChild(doc.createTextNode(obj["difficult"]))
            nodexmin.appendChild(doc.createTextNode(str(obj["bbox"][0])))
            nodeymin.appendChild(doc.createTextNode(str(obj["bbox"][1])))
            nodexmax.appendChild(doc.createTextNode(str(obj["bbox"][2])))
            nodeymax.appendChild(doc.createTextNode(str(obj["bbox"][3])))

            nodebndbox.appendChild(nodexmin)
            nodebndbox.appendChild(nodeymin)
            nodebndbox.appendChild(nodexmax)
            nodebndbox.appendChild(nodeymax)

            nodeobject.appendChild(nodename)
            nodeobject.appendChild(nodepose)
            nodeobject.appendChild(nodetruncated)
            nodeobject.appendChild(nodedifficult)
            nodeobject.appendChild(nodebndbox)

            root.appendChild(nodeobject)

        fp = open(save_xml + imgname[:-4] + "_3.xml", "w")
        doc.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding="utf-8")
        fp.close()

如需完整代码可联系我。

### 关于目标检测中的数据增强方法 在计算机视觉领域,特别是针对目标检测任务,数据增强技术扮演着至关重要的角色。这些技术能够有效提升模型泛化能力,减少过拟合风险,并提高最终性能。 #### 常见的数据增强方式 1. **几何变换** 几何变换包括旋转、翻转以及缩放等操作。这类转换可以改变图片的空间布局而不影响其语义信息。例如,在不影响物体类别的情况下随机水平或垂直翻转图像[^2]。 2. **颜色空间调整** 颜色抖动(Color Jittering),即对RGB通道分别施加亮度、对比度、饱和度和平滑度的变化;还有灰度化处理,即将彩色图转化为单通道灰阶表示形式。这有助于模拟不同的光照条件和环境因素的影响。 3. **裁剪与填充** 裁剪是指从原始图像中选取子窗口作为新的样本;而填充则是指当执行某些几何变形后可能会超出边界,则需采用适当策略补充缺失部分。这两种手段共同作用可增加训练集多样性并防止特定模式过度适应。 4. **混合样例生成** Mixup 和 Cutmix 是两种流行的混合样例生成算法。前者通过对两个不同类别的实例按一定比例线性组合创建新样本;后者则是在原图上挖去一块矩形区域再粘贴另一张同类标签下的片段。此类做法不仅扩充了可用素材数量还促进了跨类别间关系的学习。 5. **噪声注入** 向输入加入轻微扰动如高斯白噪或者椒盐噪音,以此测试网络鲁棒性和抗干扰特性。值得注意的是,适量的噪声可以帮助改善模型表现但过多反而会损害效果。 6. **多尺度训练** 使用图像金字塔结构将同一幅源片映射成多个尺寸版本参与迭代更新过程。对于大型物体而言更适合较大视野范围内的表达,反之亦然。这种方法特别适用于跨越多种规模的目标定位场合[^4]。 7. **Cutout** 这种简单却有效的技巧涉及遮挡掉一部分像素值形成黑色方块状空白区。尽管看似粗暴但却能显著降低记忆型偏差从而促使更深层次特征提取机制发挥作用。 8. **Mosaic 数据增强** Mosaic 将四张不同照片拼接在一起构成一张复合版面供后续分析利用。相比传统单一视角呈现方式更能体现全局上下文关联性,尤其适合街景监控等领域应用需求。 9. **AutoAugment/Random Erasing** AutoAugment 自动搜索最优参数配置方案以指导具体实施细节;Random Erasing 则是随机抹除若干连续位置上的数值达到类似目的。两者均体现了智能化程度较高的自动化探索思路。 ```python import torchvision.transforms as transforms transform = transforms.Compose([ transforms.RandomHorizontalFlip(), # 水平翻转 transforms.ColorJitter(brightness=0.2, contrast=0.2), # 随机色彩变化 transforms.Resize((224, 224)), # 改变大小至固定宽高 ]) ```
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值