Copy-Paste (CVPR 2021): Principles and Code Walkthrough

Paper: Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation

Official implementation: https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/copy_paste

Third-party implementation:

https://github.com/open-mmlab/mmdetection/tree/main/configs/simple_copy_paste

Overview

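This paper proposes Copy-Paste, a new data augmentation method for instance segmentation that significantly improves instance segmentation accuracy. Augmentations such as scale jittering and random resizing were already used in earlier instance segmentation models. However, they are general-purpose methods and were not designed specifically for instance segmentation.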

Differences from other methods

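Copy-Paste is somewhat similar to MixUp and CutMix. The difference is that it copies only the exact pixels belonging to an object, not all the pixels inside the object's bounding box.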
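The toy numpy comparison below makes the difference concrete. It is my own illustration, not code from the paper or from mmdetection. A mask-based paste copies only the object's pixels, while a CutMix-style box paste also copies the background inside the box. The helper names `paste_by_mask` and `paste_by_box` are hypothetical.

```python
import numpy as np

def paste_by_mask(src, dst, mask):
    """Copy-Paste style: only pixels inside the instance mask come from src."""
    m = mask[..., None].astype(src.dtype)  # (H, W, 1) broadcasts over channels
    return src * m + dst * (1 - m)

def paste_by_box(src, dst, box):
    """CutMix style: every pixel inside box = (x1, y1, x2, y2) comes from src."""
    x1, y1, x2, y2 = box
    out = dst.copy()
    out[y1:y2, x1:x2] = src[y1:y2, x1:x2]
    return out

H, W = 6, 6
src = np.full((H, W, 3), 255, dtype=np.uint8)  # bright "source" image
dst = np.zeros((H, W, 3), dtype=np.uint8)      # dark "destination" image
mask = np.zeros((H, W), dtype=np.uint8)
mask[2:4, 2:4] = 1                             # a 2x2 object...
box = (1, 1, 5, 5)                             # ...inside a loose 4x4 box

by_mask = paste_by_mask(src, dst, mask)
by_box = paste_by_box(src, dst, box)
print((by_mask[..., 0] > 0).sum(), (by_box[..., 0] > 0).sum())  # 4 16
```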

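Compared with Contextual Copy-Paste and InstaBoost, one key difference stands out. This method does not need to model the surrounding visual context to decide where to place the copied object instances. Simple random placement already gives a clear improvement.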

Method

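The way the paper generates new data with copy-paste is very simple:

1. Select two images at random.
2. Apply random scale jittering and random horizontal flipping to each image.
3. Randomly select a subset of all objects in one image and paste them onto the other image.
4. Adjust the ground-truth annotations accordingly: remove fully occluded objects, and update the masks and bounding boxes of partially occluded objects.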
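The steps above can be sketched in plain numpy. This is a minimal sketch under stated assumptions, not the official implementation:

- Both images are assumed to be scale-jittered to the same size already.
- Masks are boolean `(N, H, W)` arrays.
- `simple_copy_paste`, `random_hflip` and `masks_to_boxes` are hypothetical helpers.
- Unlike mmdetection, it drops only destination objects whose masks become completely empty, with no occlusion thresholds.

```python
import numpy as np

def masks_to_boxes(masks):
    """(N, H, W) bool masks -> (N, 4) [x1, y1, x2, y2]; x2/y2 are exclusive."""
    boxes = np.zeros((len(masks), 4), dtype=np.float32)
    for i, m in enumerate(masks):
        ys, xs = np.nonzero(m)
        if xs.size:
            boxes[i] = [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]
    return boxes

def random_hflip(img, masks, rng):
    # flip image and masks together so they stay aligned
    if rng.random() < 0.5:
        return img[:, ::-1], masks[:, :, ::-1]
    return img, masks

def simple_copy_paste(src, dst, rng):
    """src/dst: dicts with 'img' (H, W, 3), 'masks' (N, H, W) bool, 'labels' (N,).
    Both images are assumed to be scale-jittered to the same (H, W) already."""
    src_img, src_masks = random_hflip(src["img"], src["masks"], rng)
    dst_img, dst_masks = random_hflip(dst["img"], dst["masks"], rng)
    # randomly choose a subset of the source objects (possibly empty)
    n = len(src_masks)
    keep = rng.choice(n, size=int(rng.integers(0, n + 1)), replace=False)
    src_masks, src_labels = src_masks[keep], src["labels"][keep]
    alpha = src_masks.any(axis=0)                        # union of pasted objects
    img = np.where(alpha[..., None], src_img, dst_img)  # I1*alpha + I2*(1-alpha)
    # pasted objects occlude the destination objects
    dst_masks = dst_masks & ~alpha
    visible = dst_masks.any(axis=(1, 2))                 # drop fully occluded objects
    masks = np.concatenate([dst_masks[visible], src_masks])
    labels = np.concatenate([dst["labels"][visible], src_labels])
    return {"img": img, "masks": masks, "labels": labels,
            "boxes": masks_to_boxes(masks)}              # boxes recomputed from masks

rng = np.random.default_rng(0)
H = W = 8
src = {"img": np.full((H, W, 3), 200, np.uint8),
       "masks": np.zeros((2, H, W), bool), "labels": np.array([1, 2])}
src["masks"][0, 1:4, 1:4] = True
src["masks"][1, 5:7, 5:7] = True
dst = {"img": np.zeros((H, W, 3), np.uint8),
       "masks": np.zeros((1, H, W), bool), "labels": np.array([0])}
dst["masks"][0, :, 0:2] = True
out = simple_copy_paste(src, dst, rng)
```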

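To compose the new objects into an image, the binary mask \(\alpha\) of the pasted objects is first computed from the ground-truth annotations. The new image is then obtained as \(I_{1}\times \alpha+I_{2}\times (1-\alpha)\), where \(I_{1}\) is the image the pasted objects are taken from and \(I_{2}\) is the image they are pasted onto. To smooth the edges of the pasted objects, the paper applies a Gaussian filter to \(\alpha\). However, the authors found that omitting any smoothing gives similar performance.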
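The following is a minimal numpy sketch of the blending formula, with an optional Gaussian-smoothed \(\alpha\). The separable blur is hand-written to avoid extra dependencies. The `sigma` value used below is an arbitrary assumption, not a setting taken from the paper.

```python
import numpy as np

def gaussian_kernel1d(sigma):
    radius = max(1, int(3 * sigma))
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()                       # normalised so weights sum to 1

def smooth_alpha(alpha, sigma):
    """Separable Gaussian blur of an (H, W) binary mask; output keeps shape (H, W)."""
    k = gaussian_kernel1d(sigma)
    r = len(k) // 2
    a = np.pad(alpha.astype(np.float64), r, mode="edge")
    a = np.apply_along_axis(np.convolve, 1, a, k, mode="valid")  # blur rows
    a = np.apply_along_axis(np.convolve, 0, a, k, mode="valid")  # blur columns
    return a

def blend(src_img, dst_img, alpha, sigma=None):
    """I1 * alpha + I2 * (1 - alpha) with I1 = src_img (objects), I2 = dst_img."""
    a = alpha.astype(np.float64) if sigma is None else smooth_alpha(alpha, sigma)
    a = a[..., None]                         # broadcast over the channel axis
    out = src_img * a + dst_img * (1.0 - a)
    return np.rint(out).astype(src_img.dtype)

H = W = 16
src = np.full((H, W, 3), 200, np.uint8)
dst = np.zeros((H, W, 3), np.uint8)
alpha = np.zeros((H, W), np.uint8)
alpha[4:12, 4:12] = 1
hard = blend(src, dst, alpha)                # plain binary alpha
soft = blend(src, dst, alpha, sigma=1.0)     # Gaussian-smoothed alpha
```

With the hard mask, the edge pixel `(4, 4)` is fully copied. With smoothing, it becomes an intermediate value, while pixels far inside or outside the object are unchanged.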

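The paper combines Copy-Paste with two different scale augmentations: standard scale jittering (SSJ) and large scale jittering (LSJ). Figure 3 of the paper illustrates the difference between them. Most previous work used SSJ. The authors found experimentally that combining LSJ with Copy-Paste instead yields a significant performance gain.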
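For reference, the paper describes the two resize ranges as follows, each followed by cropping or padding to a fixed output size:

- SSJ: 0.8 to 1.25 of the original size.
- LSJ: 0.1 to 2.0 of the original size.

The sketch below shows the idea with nearest-neighbour resizing. `scale_jitter` and `resize_nearest` are hypothetical helpers, and padding goes to the bottom-right. In a real pipeline, the same scale and crop offset must also be applied to the masks and boxes.

```python
import numpy as np

SSJ_RANGE = (0.8, 1.25)   # standard scale jittering
LSJ_RANGE = (0.1, 2.0)    # large scale jittering

def resize_nearest(img, scale):
    h, w = img.shape[:2]
    nh, nw = max(1, round(h * scale)), max(1, round(w * scale))
    rows = (np.arange(nh) * h / nh).astype(int)   # nearest source row per output row
    cols = (np.arange(nw) * w / nw).astype(int)
    return img[rows][:, cols]

def scale_jitter(img, out_size, scale_range, rng, fill=0):
    """Resize by a random factor, then randomly crop (if larger) or pad (if smaller)."""
    scale = rng.uniform(*scale_range)
    r = resize_nearest(img, scale)
    h, w = r.shape[:2]
    oh, ow = out_size
    y0 = rng.integers(0, h - oh + 1) if h > oh else 0   # random crop offset
    x0 = rng.integers(0, w - ow + 1) if w > ow else 0
    r = r[y0:y0 + oh, x0:x0 + ow]
    out = np.full((oh, ow) + img.shape[2:], fill, dtype=img.dtype)
    out[:r.shape[0], :r.shape[1]] = r                   # pad at bottom/right
    return out, scale

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
out_ssj, s_ssj = scale_jitter(img, (64, 64), SSJ_RANGE, rng)
out_lsj, s_lsj = scale_jitter(img, (64, 64), LSJ_RANGE, rng)
```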

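The paper also studies Copy-Paste as a way to incorporate additional unlabeled images. Self-training Copy-Paste works as follows:

(1) Train a supervised model with Copy-Paste on the labeled data.
(2) Generate pseudo-labels on the unlabeled data.
(3) Paste ground-truth instances onto both the pseudo-labeled and the labeled images, and train a new model on this new data.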
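The three steps can be written down as a short driver function. Everything here is hypothetical scaffolding. `train`, `copy_paste` and the sample format are placeholders for a real training loop and augmentation, and the stubs at the bottom only check the data flow.

```python
import numpy as np

def self_training_copy_paste(labeled, unlabeled, train, copy_paste, rng):
    """labeled: list of annotated samples; unlabeled: list of raw images.
    train(samples) -> model (a callable image -> sample) and
    copy_paste(src, dst) -> sample are hypothetical, user-supplied callables."""
    # (1) supervised model trained on labeled data (Copy-Paste used inside train)
    teacher = train(labeled)
    # (2) pseudo-label the unlabeled images
    pseudo_labeled = [teacher(img) for img in unlabeled]
    # (3) paste ground-truth instances onto pseudo-labeled and labeled images
    new_data = []
    for dst in pseudo_labeled + list(labeled):
        src = labeled[rng.integers(len(labeled))]   # GT instances come from labeled data
        new_data.append(copy_paste(src, dst))
    return train(new_data)                          # train a new model on the new data

# toy usage with stand-in callables
sizes = []
def train(samples):
    sizes.append(len(samples))
    return lambda img: {"img": img, "labels": ["pseudo"]}

def copy_paste(src, dst):
    return {"img": dst["img"], "labels": dst["labels"] + src["labels"]}

labeled = [{"img": f"l{i}", "labels": ["gt"]} for i in range(3)]
student = self_training_copy_paste(labeled, ["u0", "u1"], train, copy_paste,
                                   np.random.default_rng(0))
print(sizes)  # [3, 5]
```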

Code walkthrough

The complete Copy-Paste implementation in mmdetection is as follows:

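# Imports needed to run this class on its own (mmdetection 3.x / mmcv 2.x / mmengine).
# Note that mmdet's transforms module uses numpy's random, so random.randint in
# get_indexes has an exclusive upper bound.
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.dataset import BaseDataset
from numpy import random

from mmdet.structures.bbox import autocast_box_type
from mmdet.structures.mask import BitmapMasks


class CopyPaste(BaseTransform):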
    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
    Segmentation. The simple copy-paste transform steps are as follows:

    1. The destination image is already resized with aspect ratio kept,
       cropped and padded.
    2. Randomly select a source image, which is also already resized
       with aspect ratio kept, cropped and padded in a similar way
       as the destination image.
    3. Randomly select some objects from the source image.
    4. Paste these source objects to the destination image directly,
       because the source and destination images have the same size.
    5. Update object masks of the destination image, since some original
       objects may be occluded.
    6. Generate bboxes from the updated destination masks and
       filter some objects which are totally occluded, and adjust bboxes
       which are partly occluded.
    7. Append selected source bboxes, masks, and labels.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_masks (BitmapMasks) (optional)

    Modified Keys:

    - img
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    - gt_masks (optional)

    Args:
        max_num_pasted (int): The maximum number of pasted objects.
            Defaults to 100.
        bbox_occluded_thr (int): The threshold of occluded bbox.
            Defaults to 10.
        mask_occluded_thr (int): The threshold of occluded mask.
            Defaults to 300.
        selected (bool): Whether select objects or not. If select is False,
            all objects of the source image will be pasted to the
            destination image.
            Defaults to True.
    """

    def __init__(
        self,
        max_num_pasted: int = 100,
        bbox_occluded_thr: int = 10,
        mask_occluded_thr: int = 300,
        selected: bool = True,
    ) -> None:
        self.max_num_pasted = max_num_pasted
        self.bbox_occluded_thr = bbox_occluded_thr
        self.mask_occluded_thr = mask_occluded_thr
        self.selected = selected

    @cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> int:
        """Call function to collect indexes.s.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.
        Returns:
            list: Indexes.
        """
        return random.randint(0, len(dataset))

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Transform function to make a copy-paste of image.

        Args:
            results (dict): Result dict.
        Returns:
            dict: Result dict with copy-paste transformed.
        """

        assert 'mix_results' in results
        num_images = len(results['mix_results'])
        assert num_images == 1, \
            f'CopyPaste only supports processing 2 images, got {num_images}'
        if self.selected:
            selected_results = self._select_object(results['mix_results'][0])
        else:
            selected_results = results['mix_results'][0]
        return self._copy_paste(results, selected_results)

    @cache_randomness
    def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
        max_num_pasted = min(num_bboxes + 1, self.max_num_pasted)
        num_pasted = np.random.randint(0, max_num_pasted)
        return np.random.choice(num_bboxes, size=num_pasted, replace=False)

    def _select_object(self, results: dict) -> dict:
        """Select some objects from the source results."""
        bboxes = results['gt_bboxes']
        labels = results['gt_bboxes_labels']
        masks = results['gt_masks']
        ignore_flags = results['gt_ignore_flags']

        selected_inds = self._get_selected_inds(bboxes.shape[0])

        selected_bboxes = bboxes[selected_inds]
        selected_labels = labels[selected_inds]
        selected_masks = masks[selected_inds]
        selected_ignore_flags = ignore_flags[selected_inds]

        results['gt_bboxes'] = selected_bboxes
        results['gt_bboxes_labels'] = selected_labels
        results['gt_masks'] = selected_masks
        results['gt_ignore_flags'] = selected_ignore_flags
        return results

    def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
        """CopyPaste transform function.

        Args:
            dst_results (dict): Result dict of the destination image.
            src_results (dict): Result dict of the source image.
        Returns:
            dict: Updated result dict.
        """
        dst_img = dst_results['img']
        dst_bboxes = dst_results['gt_bboxes']
        dst_labels = dst_results['gt_bboxes_labels']
        dst_masks = dst_results['gt_masks']
        dst_ignore_flags = dst_results['gt_ignore_flags']

        src_img = src_results['img']
        src_bboxes = src_results['gt_bboxes']
        src_labels = src_results['gt_bboxes_labels']
        src_masks = src_results['gt_masks']
        src_ignore_flags = src_results['gt_ignore_flags']

        if len(src_bboxes) == 0:
            return dst_results

        # update masks and generate bboxes from updated masks
        composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
        updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)
        updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes))
        assert len(updated_dst_bboxes) == len(updated_dst_masks)

        # filter totally occluded objects
        l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
        bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
            dim=-1).numpy()
        masks_inds = updated_dst_masks.masks.sum(
            axis=(1, 2)) > self.mask_occluded_thr
        valid_inds = bboxes_inds | masks_inds

        # Paste source objects to destination image directly
        img = dst_img * (1 - composed_mask[..., np.newaxis]
                         ) + src_img * composed_mask[..., np.newaxis]
        bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
        labels = np.concatenate([dst_labels[valid_inds], src_labels])
        masks = np.concatenate(
            [updated_dst_masks.masks[valid_inds], src_masks.masks])
        ignore_flags = np.concatenate(
            [dst_ignore_flags[valid_inds], src_ignore_flags])

        dst_results['img'] = img
        dst_results['gt_bboxes'] = bboxes
        dst_results['gt_bboxes_labels'] = labels
        dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
                                              masks.shape[2])
        dst_results['gt_ignore_flags'] = ignore_flags

        return dst_results

    def _get_updated_masks(self, masks: BitmapMasks,
                           composed_mask: np.ndarray) -> BitmapMasks:
        """Update masks with composed mask."""
        assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
            'Cannot compare two arrays of different size'
        masks.masks = np.where(composed_mask, 0, masks.masks)
        return masks

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(max_num_pasted={self.max_num_pasted}, '
        repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
        repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
        repr_str += f'selected={self.selected})'
        return repr_str

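Here `transform` is the entry point. Its argument `results` holds the information for both images: the image, mask annotations, box annotations and so on. For augmentations that combine several images, everything except the main image is stored in `results['mix_results']`. MultiImageMixDataset fills this field with the samples whose indexes are returned by `get_indexes`. Copy-Paste combines only two images, so `mix_results` holds the information of just one image. For the 4-image Mosaic augmentation, `mix_results` holds the other 3 images.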
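How `mix_results` gets filled can be illustrated with a stripped-down stand-in for mmdet's `MultiImageMixDataset`. `ToyMixDataset` and `ToyCopyPaste` are hypothetical classes, not mmdet APIs. They only mimic the pattern of calling the transform's `get_indexes` and attaching the extra samples before the transform runs.

```python
import random

class ToyMixDataset:
    """Hypothetical, much-simplified stand-in for MultiImageMixDataset:
    it only shows how results['mix_results'] gets filled."""
    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __getitem__(self, idx):
        results = dict(self.samples[idx])            # the "main" (destination) image
        if hasattr(self.transform, "get_indexes"):
            indexes = self.transform.get_indexes(self.samples)
            if not isinstance(indexes, (list, tuple)):
                indexes = [indexes]                  # a CopyPaste-like transform returns one int
            results["mix_results"] = [dict(self.samples[i]) for i in indexes]
        results = self.transform(results)
        results.pop("mix_results", None)             # not needed after mixing
        return results

class ToyCopyPaste:
    def get_indexes(self, dataset):
        return random.randrange(len(dataset))        # exactly one extra image

    def __call__(self, results):
        assert len(results["mix_results"]) == 1      # same check as CopyPaste.transform
        src = results["mix_results"][0]
        results["labels"] = results["labels"] + src["labels"]
        return results

random.seed(0)
ds = ToyMixDataset([{"labels": ["cat"]}, {"labels": ["dog"]}], ToyCopyPaste())
out = ds[0]
```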

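The function `_select_object` picks the objects to paste by randomly selecting a subset of all instances in the source image. `_get_selected_inds` draws the subset size from 0 up to min(number of instances, max_num_pasted - 1). It then samples the indexes without replacement.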
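The sampling in `_get_selected_inds` can be reproduced on its own. `np.random.randint` excludes its upper bound, which is why the number of pasted objects tops out at `min(num_bboxes, max_num_pasted - 1)`. The standalone function below copies that logic, with a hypothetical `rng` argument added for reproducibility.

```python
import numpy as np

def get_selected_inds(num_bboxes, max_num_pasted=100, rng=np.random):
    """Same logic as CopyPaste._get_selected_inds, minus the cache decorator."""
    upper = min(num_bboxes + 1, max_num_pasted)  # exclusive bound for randint
    num_pasted = rng.randint(0, upper)           # 0 .. min(num_bboxes, max_num_pasted - 1)
    return rng.choice(num_bboxes, size=num_pasted, replace=False)  # distinct indexes

rs = np.random.RandomState(0)
sizes_5 = {len(get_selected_inds(5, rng=rs)) for _ in range(500)}
sizes_cap = {len(get_selected_inds(5, max_num_pasted=3, rng=rs)) for _ in range(500)}
print(sorted(sizes_5), sorted(sizes_cap))  # [0, 1, 2, 3, 4, 5] [0, 1, 2]
```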

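The function `_copy_paste` does three things:

- It pastes the selected objects.
- It filters out destination objects that are occluded too much by the pasted objects.
- It updates the annotations.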

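In mmdet's instance segmentation mask representation (here `src_masks.masks`), each channel represents one instance, regardless of its class. The following line merges the masks of all objects to be pasted into a single-channel mask: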

composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
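Here is a tiny example of what this line produces. Two overlapping instances collapse into one binary map whose area equals the size of their union.

```python
import numpy as np

src_masks = np.zeros((2, 4, 4), dtype=np.uint8)  # (N, H, W): one channel per instance
src_masks[0, 0:2, 0:2] = 1
src_masks[1, 1:3, 1:3] = 1                       # overlaps instance 0 at pixel (1, 1)
composed_mask = np.where(np.any(src_masks, axis=0), 1, 0)
print(composed_mask.shape, composed_mask.sum())  # (4, 4) 7
```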

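Next, `self._get_updated_masks` takes two inputs: the masks of the destination image (`dst_masks`) and the merged mask of the pasted objects (`composed_mask`). It returns the parts of `dst_masks` that are not covered by `composed_mask`.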

updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)

def _get_updated_masks(self, masks: BitmapMasks,
                       composed_mask: np.ndarray) -> BitmapMasks:
    """Update masks with composed mask."""
    assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
        'Cannot compare two arrays of different size'
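    # composed_mask is single-channel (H, W); masks.masks may have many channels (N, H, W)
    masks.masks = np.where(composed_mask, 0, masks.masks)
    # the result keeps N channels; the remaining 1s are the parts of masks.masks
    # not covered by composed_mask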
    return masks
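`composed_mask` has shape `(H, W)` and `masks.masks` has shape `(N, H, W)`, so `np.where` broadcasts the single-channel mask over every instance. The small numpy check below uses plain arrays instead of `BitmapMasks`, so no mmdet install is needed. In mmdet, the method assigns the result back to `masks.masks`, so the `BitmapMasks` object passed in (the destination's `gt_masks`) is modified in place. The plain-array version here leaves `dst_masks` untouched.

```python
import numpy as np

dst_masks = np.zeros((2, 4, 4), dtype=np.uint8)
dst_masks[0, 0:2, 0:4] = 1            # instance 0: top two rows (8 px)
dst_masks[1, 2:4, 0:4] = 1            # instance 1: bottom two rows (8 px)
composed_mask = np.zeros((4, 4), dtype=np.uint8)
composed_mask[1:3, 1:3] = 1           # pasted objects cover a 2x2 block in the middle

updated = np.where(composed_mask, 0, dst_masks)  # (H, W) broadcasts against (N, H, W)
print(updated.shape, updated.sum(axis=(1, 2)).tolist())  # (2, 4, 4) [6, 6]
```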

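Next, instances are filtered using `self.bbox_occluded_thr` and `self.mask_occluded_thr`.

- `l1_distance` is the element-wise absolute difference between two boxes: the original box of each destination object, and the box of its still-visible part after pasting. It covers the top-left and bottom-right x/y coordinates. A large value on a coordinate means that side of the object has been largely covered.
- `bboxes_inds` keeps an instance if all four coordinates moved by at most `bbox_occluded_thr` pixels.
- `masks_inds` keeps an instance if its visible mask still has more than `mask_occluded_thr` pixels.

An instance is retained if either condition holds.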

# filter totally occluded objects
l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
    dim=-1).numpy()
masks_inds = updated_dst_masks.masks.sum(
    axis=(1, 2)) > self.mask_occluded_thr
valid_inds = bboxes_inds | masks_inds
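The filtering rule can be exercised on plain arrays. The sketch below reuses the default thresholds (10 and 300). It relies on a hypothetical `masks_to_bboxes` helper that computes boxes as `[x_min, y_min, x_max + 1, y_max + 1]`; this is assumed, not verified, to follow the convention of `BitmapMasks.get_bboxes`. Three destination objects show the three possible outcomes:

- A keeps a box close to the original.
- B keeps a large visible area.
- C is fully covered and dropped.

```python
import numpy as np

def masks_to_bboxes(masks):
    """Hypothetical helper: (N, H, W) -> (N, 4) [x1, y1, x2, y2]; zeros if empty."""
    bboxes = np.zeros((len(masks), 4), dtype=np.float32)
    for i, m in enumerate(masks):
        ys, xs = np.nonzero(m)
        if xs.size:
            bboxes[i] = [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]
    return bboxes

bbox_occluded_thr, mask_occluded_thr = 10, 300
H = W = 64
dst_masks = np.zeros((3, H, W), dtype=np.uint8)
dst_masks[0, 0:20, 0:20] = 1     # A: small, bottom half will be covered
dst_masks[1, 30:64, 30:64] = 1   # B: large, left part will be covered
dst_masks[2, 0:10, 40:50] = 1    # C: small, will be fully covered
dst_bboxes = masks_to_bboxes(dst_masks)

composed_mask = np.zeros((H, W), dtype=np.uint8)
composed_mask[10:20, 0:20] = 1
composed_mask[30:64, 30:45] = 1
composed_mask[0:10, 40:50] = 1

updated = np.where(composed_mask, 0, dst_masks)
updated_bboxes = masks_to_bboxes(updated)
l1_distance = np.abs(updated_bboxes - dst_bboxes)
bboxes_inds = (l1_distance <= bbox_occluded_thr).all(axis=-1)  # A: y2 moved by 10 -> kept
masks_inds = updated.sum(axis=(1, 2)) > mask_occluded_thr      # B: 34 * 19 = 646 px left
valid_inds = bboxes_inds | masks_inds
print(valid_inds.tolist())  # [True, True, False]
```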

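Finally, the image is updated as \(I_{1}\times \alpha+I_{2}\times (1-\alpha)\), where \(I_{1}\) is `src_img`, \(I_{2}\) is `dst_img` and \(\alpha\) is `composed_mask`. The boxes, masks, labels and ignore flags are updated accordingly: the surviving destination instances come first, followed by all pasted source instances.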

# Paste source objects to destination image directly
img = dst_img * (1 - composed_mask[..., np.newaxis]
                 ) + src_img * composed_mask[..., np.newaxis]
bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
labels = np.concatenate([dst_labels[valid_inds], src_labels])
masks = np.concatenate(
    [updated_dst_masks.masks[valid_inds], src_masks.masks])
ignore_flags = np.concatenate(
    [dst_ignore_flags[valid_inds], src_ignore_flags])
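Below is a small numpy run of the image blend and the label/mask concatenation, with plain arrays standing in for mmdet's box and mask types. It also shows a dtype detail of this expression. `composed_mask` comes from `np.where(..., 1, 0)` and is therefore an integer array, so NumPy promotes the uint8 images to that integer type. The example fixes it to int64 explicitly.

```python
import numpy as np

H = W = 4
dst_img = np.zeros((H, W, 3), dtype=np.uint8)
src_img = np.full((H, W, 3), 255, dtype=np.uint8)
composed_mask = np.zeros((H, W), dtype=np.int64)
composed_mask[0:2, 0:2] = 1

# I2 * (1 - alpha) + I1 * alpha, same expression as in _copy_paste
img = dst_img * (1 - composed_mask[..., np.newaxis]) + src_img * composed_mask[..., np.newaxis]

valid_inds = np.array([True, False])                      # 2nd destination object filtered out
dst_labels, src_labels = np.array([3, 7]), np.array([1])
dst_masks = np.zeros((2, H, W), dtype=np.uint8)
src_masks = composed_mask[np.newaxis].astype(np.uint8)    # one pasted object
labels = np.concatenate([dst_labels[valid_inds], src_labels])  # kept dst first, then src
masks = np.concatenate([dst_masks[valid_inds], src_masks])
print(img.dtype, labels, masks.shape)  # int64 [3 1] (2, 4, 4)
```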
