Paper: Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation
Official implementation: https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/copy_paste
Third-party implementation:
https://github.com/open-mmlab/mmdetection/tree/main/configs/simple_copy_paste
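Overview
This paper proposes Copy-Paste, a new data augmentation method for instance segmentation that significantly improves instance segmentation accuracy. Augmentations such as scale jittering and random resizing were already used in earlier instance segmentation models, but they are general-purpose methods and were not designed specifically for instance segmentation.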
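Differences from other methods
Copy-Paste is somewhat similar to MixUp and CutMix, but it copies only the exact pixels that belong to an object rather than all the pixels inside the object's bounding box.
Compared with Contextual Copy-Paste and InstaBoost, a key difference of this method is that it does not need to model the surrounding visual context to decide where to place the copied object instances. Simply placing them at random already gives a clear improvement.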
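Method
The way this paper generates new data by copy-paste is very simple. Two images are picked at random, and each one is augmented with random scale jittering and random horizontal flipping. A random subset of all the objects in one image is then selected and pasted onto the other image. Finally, the ground-truth annotations are adjusted accordingly: fully occluded objects are removed, and the masks and bounding boxes of partially occluded objects are updated.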
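Random horizontal flipping, one of the per-image augmentations mentioned above, has to move the image, its instance masks and its boxes together. A minimal NumPy sketch, assuming masks of shape (N, H, W) and boxes as [x1, y1, x2, y2] with exclusive right/bottom edges; `random_hflip` is a hypothetical helper, not code from the paper:

```python
import numpy as np

def random_hflip(img, masks, boxes, rng, prob=0.5):
    """Flip img (H, W, C), masks (N, H, W) and boxes (N, 4) together with probability prob."""
    if rng.random() >= prob:
        return img, masks, boxes
    w = img.shape[1]
    flipped = boxes.copy()
    # Pixel column c moves to w - 1 - c, so an exclusive edge x moves to w - x
    flipped[:, 0] = w - boxes[:, 2]
    flipped[:, 2] = w - boxes[:, 0]
    return img[:, ::-1], masks[:, :, ::-1], flipped
```

The same flip decision is applied to all three arrays so that the annotations stay aligned with the pixels.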
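To blend the new objects into an image, a binary mask \(\alpha\) of the objects to be pasted is first computed from the ground-truth annotations, and the new image is then obtained as \(I_{1}\times \alpha+I_{2}\times (1-\alpha)\), where \(I_{1}\) is the image the pasted objects are taken from (the source) and \(I_{2}\) is the image they are pasted onto (the destination). To smooth the edges of the pasted objects, a Gaussian filter is applied to \(\alpha\), although the authors found that compositing without any smoothing gives similar performance.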
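The compositing rule, with optional Gaussian smoothing of \(\alpha\), can be sketched in NumPy as follows. This is an illustration under assumptions: the kernel radius (3σ), the edge padding and the helper names `gaussian_kernel1d`, `smooth_alpha` and `blend` are my own choices, not taken from the paper's implementation.

```python
import numpy as np

def gaussian_kernel1d(sigma: float) -> np.ndarray:
    """Normalized 1D Gaussian kernel truncated at radius 3*sigma (assumed choice)."""
    radius = max(1, int(3 * sigma))
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()

def smooth_alpha(alpha: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur of a 2D alpha mask; edge padding keeps the shape."""
    k = gaussian_kernel1d(sigma)
    r = len(k) // 2
    out = alpha.astype(np.float64)
    for axis in (0, 1):
        pad = [(r, r) if a == axis else (0, 0) for a in range(2)]
        padded = np.pad(out, pad, mode="edge")
        out = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), axis, padded)
    return out

def blend(src_img: np.ndarray, dst_img: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """new = I1 * alpha + I2 * (1 - alpha), with I1 = source and I2 = destination."""
    a = alpha[..., np.newaxis] if src_img.ndim == 3 else alpha
    return src_img * a + dst_img * (1.0 - a)
```

With a hard (binary) \(\alpha\) every output pixel comes from exactly one image; the smoothed \(\alpha\) only mixes the two images in a thin band around each object's boundary.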
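The paper combines Copy-Paste with two different augmentations: standard scale jittering (SSJ) and large scale jittering (LSJ). Figure 3 of the paper illustrates the difference: SSJ resizes and crops the image with a resize range of 0.8 to 1.25 of the original image size, while LSJ uses a range of 0.1 to 2.0, and images that become smaller than the original are padded with gray pixels. The authors found experimentally that, compared with the standard SSJ used in most previous work, combining LSJ with Copy-Paste gives a significant performance gain.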
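A rough sketch of scale jittering under the two ranges. Assumptions: nearest-neighbour resizing, scaling relative to the input image size, a uniform random crop when the image is larger than the output, and a gray pad value of 128. Boxes and masks would need the same transform, which is omitted here; `scale_jitter` is a hypothetical helper.

```python
import numpy as np

# Resize ranges relative to the original image size, as given in Figure 3 of the paper
SCALE_RANGES = {"ssj": (0.8, 1.25), "lsj": (0.1, 2.0)}

def scale_jitter(img, out_size, mode="lsj", rng=None, pad_value=128):
    """Randomly rescale img, then crop/pad it to out_size=(height, width)."""
    rng = np.random.default_rng() if rng is None else rng
    lo, hi = SCALE_RANGES[mode]
    scale = rng.uniform(lo, hi)
    h, w = img.shape[:2]
    new_h, new_w = max(1, int(round(h * scale))), max(1, int(round(w * scale)))
    # Nearest-neighbour resize via index arrays
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    resized = img[rows][:, cols]
    # Random crop when the resized image exceeds the output size
    oh, ow = out_size
    top = rng.integers(0, max(new_h - oh, 0) + 1)
    left = rng.integers(0, max(new_w - ow, 0) + 1)
    cropped = resized[top:top + oh, left:left + ow]
    # Pad bottom/right with a constant gray value when the image is smaller
    out = np.full((oh, ow) + img.shape[2:], pad_value, dtype=img.dtype)
    out[:cropped.shape[0], :cropped.shape[1]] = cropped
    return out, scale
```

The only difference between the two modes is the sampling range: LSJ can shrink an image to a tenth of its size (mostly padding) or double it (mostly cropping), which produces far more varied object scales than SSJ.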
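The paper also studies Copy-Paste as a way to incorporate additional unlabeled images. Self-training Copy-Paste works as follows: (1) train a supervised model with Copy-Paste on the labeled data; (2) generate pseudo labels on the unlabeled data; (3) paste ground-truth instances onto both the pseudo-labeled images and the labeled images, and train a new model on this new data.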
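The three stages can be written as a schematic driver. Everything here is hypothetical: `train`, `predict` and `copy_paste` are caller-supplied callables, samples are dicts with `img` and `instances`, and the augmentation is applied once per image for brevity, whereas real training would re-sample it on the fly at every iteration.

```python
import random
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]  # hypothetical format: {"img": ..., "instances": [...]}

def self_training_copy_paste(
    labeled: List[Sample],
    unlabeled: List[Any],
    train: Callable[[List[Sample]], Any],
    predict: Callable[[Any, Any], list],
    copy_paste: Callable[[Sample, Sample], Sample],
    rng: random.Random,
) -> Any:
    # (1) Supervised teacher: paste objects from a random labeled image onto each labeled image
    teacher = train([copy_paste(dst, rng.choice(labeled)) for dst in labeled])
    # (2) Pseudo-label the unlabeled images with the teacher
    pseudo = [{"img": img, "instances": predict(teacher, img)} for img in unlabeled]
    # (3) Paste ground-truth instances onto pseudo-labeled and labeled images, train a student
    student_data = [copy_paste(dst, rng.choice(labeled)) for dst in pseudo + labeled]
    return train(student_data)
```

Note that in step (3) the pasted objects always come from labeled images, so every training image carries at least some ground-truth instances even when its own labels are pseudo labels.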
Code walkthrough
The complete CopyPaste implementation in mmdetection is as follows:
# Imports needed by the class below
import random

import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.dataset import BaseDataset

from mmdet.structures.bbox import autocast_box_type
from mmdet.structures.mask import BitmapMasks


class CopyPaste(BaseTransform):
    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
    Segmentation. The simple copy-paste transform steps are as follows:

    1. The destination image is already resized with aspect ratio kept,
       cropped and padded.
    2. Randomly select a source image, which is also already resized
       with aspect ratio kept, cropped and padded in a similar way
       as the destination image.
    3. Randomly select some objects from the source image.
    4. Paste these source objects to the destination image directly,
       since the source and destination images have the same size.
    5. Update object masks of the destination image, since some original
       objects may be occluded.
    6. Generate bboxes from the updated destination masks and
       filter some objects which are totally occluded, and adjust bboxes
       which are partly occluded.
    7. Append selected source bboxes, masks, and labels.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_masks (BitmapMasks) (optional)

    Modified Keys:

    - img
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    - gt_masks (optional)

    Args:
        max_num_pasted (int): The maximum number of pasted objects.
            Defaults to 100.
        bbox_occluded_thr (int): The threshold of occluded bbox.
            Defaults to 10.
        mask_occluded_thr (int): The threshold of occluded mask.
            Defaults to 300.
        selected (bool): Whether to select objects or not. If selected is
            False, all objects of the source image will be pasted to the
            destination image.
            Defaults to True.
    """

    def __init__(
        self,
        max_num_pasted: int = 100,
        bbox_occluded_thr: int = 10,
        mask_occluded_thr: int = 300,
        selected: bool = True,
    ) -> None:
        self.max_num_pasted = max_num_pasted
        self.bbox_occluded_thr = bbox_occluded_thr
        self.mask_occluded_thr = mask_occluded_thr
        self.selected = selected

    @cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            int: Index of the source image.
        """
        # random.randint includes both endpoints, so subtract 1 to stay in range
        return random.randint(0, len(dataset) - 1)

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Transform function to make a copy-paste of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with copy-paste transformed.
        """
        assert 'mix_results' in results
        num_images = len(results['mix_results'])
        assert num_images == 1, \
            f'CopyPaste only supports processing 2 images, got {num_images}'
        if self.selected:
            selected_results = self._select_object(results['mix_results'][0])
        else:
            selected_results = results['mix_results'][0]
        return self._copy_paste(results, selected_results)

    @cache_randomness
    def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
        max_num_pasted = min(num_bboxes + 1, self.max_num_pasted)
        num_pasted = np.random.randint(0, max_num_pasted)
        return np.random.choice(num_bboxes, size=num_pasted, replace=False)

    def _select_object(self, results: dict) -> dict:
        """Select some objects from the source results."""
        bboxes = results['gt_bboxes']
        labels = results['gt_bboxes_labels']
        masks = results['gt_masks']
        ignore_flags = results['gt_ignore_flags']

        selected_inds = self._get_selected_inds(bboxes.shape[0])

        selected_bboxes = bboxes[selected_inds]
        selected_labels = labels[selected_inds]
        selected_masks = masks[selected_inds]
        selected_ignore_flags = ignore_flags[selected_inds]

        results['gt_bboxes'] = selected_bboxes
        results['gt_bboxes_labels'] = selected_labels
        results['gt_masks'] = selected_masks
        results['gt_ignore_flags'] = selected_ignore_flags
        return results

    def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
        """CopyPaste transform function.

        Args:
            dst_results (dict): Result dict of the destination image.
            src_results (dict): Result dict of the source image.

        Returns:
            dict: Updated result dict.
        """
        dst_img = dst_results['img']
        dst_bboxes = dst_results['gt_bboxes']
        dst_labels = dst_results['gt_bboxes_labels']
        dst_masks = dst_results['gt_masks']
        dst_ignore_flags = dst_results['gt_ignore_flags']

        src_img = src_results['img']
        src_bboxes = src_results['gt_bboxes']
        src_labels = src_results['gt_bboxes_labels']
        src_masks = src_results['gt_masks']
        src_ignore_flags = src_results['gt_ignore_flags']

        if len(src_bboxes) == 0:
            return dst_results

        # update masks and generate bboxes from updated masks
        composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
        updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)
        updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes))
        assert len(updated_dst_bboxes) == len(updated_dst_masks)

        # filter totally occluded objects
        l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
        bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
            dim=-1).numpy()
        masks_inds = updated_dst_masks.masks.sum(
            axis=(1, 2)) > self.mask_occluded_thr
        valid_inds = bboxes_inds | masks_inds

        # Paste source objects to destination image directly
        img = dst_img * (1 - composed_mask[..., np.newaxis]
                         ) + src_img * composed_mask[..., np.newaxis]
        bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
        labels = np.concatenate([dst_labels[valid_inds], src_labels])
        masks = np.concatenate(
            [updated_dst_masks.masks[valid_inds], src_masks.masks])
        ignore_flags = np.concatenate(
            [dst_ignore_flags[valid_inds], src_ignore_flags])

        dst_results['img'] = img
        dst_results['gt_bboxes'] = bboxes
        dst_results['gt_bboxes_labels'] = labels
        dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
                                              masks.shape[2])
        dst_results['gt_ignore_flags'] = ignore_flags

        return dst_results

    def _get_updated_masks(self, masks: BitmapMasks,
                           composed_mask: np.ndarray) -> BitmapMasks:
        """Update masks with composed mask."""
        assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
            'Cannot compare two arrays of different size'
        masks.masks = np.where(composed_mask, 0, masks.masks)
        return masks

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(max_num_pasted={self.max_num_pasted}, '
        repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
        repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
        repr_str += f'selected={self.selected})'
        return repr_str
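Here `transform` is the entry point. Its argument `results` carries the information of both images, including the image, the mask annotations, the box annotations and so on. For augmentations that combine several images, every image other than the main one is stored in `results['mix_results']`. Copy-Paste works on only two images, so `mix_results` holds the information of a single image; for 4-image Mosaic augmentation, `mix_results` holds the other 3 images.
The function `_select_object` extracts the objects to be pasted: from all instances in the source image, it randomly selects a subset. Through `_get_selected_inds`, the subset size is drawn uniformly from 0 to min(number of instances, `max_num_pasted` − 1), and the indices are sampled without replacement. Because the size can be 0, sometimes nothing is selected, in which case `_copy_paste` returns the destination results unchanged.
The function `_copy_paste` does the pasting, filters out destination objects that are too heavily occluded by the pasted objects, and updates the annotations.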
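To see where `mix_results` comes from, here is a simplified stand-in for the dataset wrapper (in mmdet this role is played by `MultiImageMixDataset`). The classes `ToyMixDataset` and `ToyTwoImageTransform` are hypothetical and only mimic the flow as I understand it: ask the transform for extra indexes via `get_indexes`, load those samples into `mix_results`, run the transform, then drop `mix_results`.

```python
import copy
import random

class ToyMixDataset:
    """Hypothetical, simplified stand-in for mmdet's MultiImageMixDataset."""

    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        results = copy.deepcopy(self.samples[idx])
        indexes = self.transform.get_indexes(self.samples)
        if not isinstance(indexes, (list, tuple)):
            indexes = [indexes]  # a two-image transform returns a single int
        results["mix_results"] = [copy.deepcopy(self.samples[i]) for i in indexes]
        results = self.transform.transform(results)
        results.pop("mix_results", None)  # the extra images are dropped after mixing
        return results

class ToyTwoImageTransform:
    """Hypothetical transform that, like CopyPaste, needs exactly one extra image."""

    def get_indexes(self, dataset):
        return random.randint(0, len(dataset) - 1)  # randint is inclusive on both ends

    def transform(self, results):
        assert len(results["mix_results"]) == 1
        results["paired_with"] = results["mix_results"][0]["id"]
        return results
```

A Mosaic-style transform would instead return a list of 3 indexes from `get_indexes`, giving 3 entries in `mix_results`.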
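The sampling done by `_get_selected_inds` can be reproduced with plain NumPy to check its range. This sketch uses `numpy.random.Generator` instead of the legacy `np.random` functions in the quoted code; `get_selected_inds` is a hypothetical standalone version, and its empty-input guard is added only for clarity.

```python
import numpy as np

def get_selected_inds(num_bboxes, max_num_pasted=100, rng=None):
    """Pick a random subset of instance indices, as _get_selected_inds does."""
    rng = np.random.default_rng() if rng is None else rng
    if num_bboxes == 0:  # guard added for clarity
        return np.empty(0, dtype=np.int64)
    upper = min(num_bboxes + 1, max_num_pasted)   # exclusive bound on the count
    num_pasted = rng.integers(0, upper)           # 0 .. min(num_bboxes, max_num_pasted - 1)
    return rng.choice(num_bboxes, size=num_pasted, replace=False)
```

So with the default `max_num_pasted=100`, at most 99 objects are pasted from one source image, and never more than the image contains.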
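In mmdet's mask representation for instance segmentation (here `src_masks.masks`), each channel is one instance regardless of its class, i.e. an array of shape (N, H, W). The following line merges the masks of all objects to be pasted into a single-channel mask: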
composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
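A toy example with made-up masks shows the shape convention and the union step:

```python
import numpy as np

# Two toy source instances on a 4x6 canvas, stored as (N, H, W) like BitmapMasks.masks
src_masks = np.zeros((2, 4, 6), dtype=np.uint8)
src_masks[0, 0:2, 0:2] = 1   # instance 0: 4 pixels
src_masks[1, 1:3, 1:4] = 1   # instance 1: 6 pixels, overlapping instance 0 at (1, 1)

# Union over the instance axis gives one class-agnostic (H, W) mask
composed_mask = np.where(np.any(src_masks, axis=0), 1, 0)
```

The overlapping pixel is counted once, so the union covers 4 + 6 − 1 = 9 pixels.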
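Next, `self._get_updated_masks` takes the mask of the objects to be pasted, `composed_mask`, and the destination masks `dst_masks`, and returns the parts of `dst_masks` that are not covered by `composed_mask`. Note that it modifies the `dst_masks` object in place.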
updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)
def _get_updated_masks(self, masks: BitmapMasks,
                       composed_mask: np.ndarray) -> BitmapMasks:
    """Update masks with composed mask."""
    assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
        'Cannot compare two arrays of different size'
    # composed_mask is single-channel (H, W) while masks.masks may have many
    # channels (N, H, W); np.where broadcasts it over every instance
    masks.masks = np.where(composed_mask, 0, masks.masks)
    # The result keeps N channels; pixels equal to 1 are the parts of the
    # original masks that composed_mask does not cover
    return masks
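The occlusion update and the box recomputation can be checked on toy arrays. `update_masks` applies the same `np.where` broadcast as above; `masks_to_bboxes` is a hypothetical helper written to mirror what `BitmapMasks.get_bboxes` does as far as I know (tight boxes with +1 on x2/y2, all-zero boxes for empty masks).

```python
import numpy as np

def update_masks(dst_masks, composed_mask):
    """Zero out destination pixels covered by pasted objects.
    composed_mask (H, W) broadcasts over dst_masks (N, H, W)."""
    return np.where(composed_mask, 0, dst_masks)

def masks_to_bboxes(masks):
    """Tight [x1, y1, x2, y2] boxes from (N, H, W) masks; empty masks give zeros."""
    boxes = np.zeros((masks.shape[0], 4), dtype=np.float32)
    for i, m in enumerate(masks):
        ys, xs = np.nonzero(m)
        if len(xs):
            boxes[i] = [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]
    return boxes

dst = np.zeros((2, 6, 8), dtype=np.uint8)
dst[0, 1:5, 1:7] = 1    # large object, its top half will be covered
dst[1, 0:2, 0:2] = 1    # small object, will be fully covered
composed = np.zeros((6, 8), dtype=np.uint8)
composed[0:3, :] = 1    # pasted objects cover the top three rows
upd = update_masks(dst, composed)
```

Here the large object's box shrinks from [1, 1, 7, 5] to [1, 3, 7, 5], and the small object ends up with an empty mask and an all-zero box.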
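Next, instances that fail the checks are filtered out using `self.bbox_occluded_thr` and `self.mask_occluded_thr`. `l1_distance` is the absolute difference, coordinate by coordinate (top-left and bottom-right x and y), between each destination object's original box and the box of its remaining unoccluded part; the larger the distance, the more of the object is covered. `masks_inds` is based on the number of pixels left in the unoccluded part of the mask. An object is kept if every box coordinate moved by at most `bbox_occluded_thr` pixels, or if more than `mask_occluded_thr` mask pixels remain visible; only objects that fail both checks are dropped.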
# filter totally occluded objects
l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
    dim=-1).numpy()
masks_inds = updated_dst_masks.masks.sum(
    axis=(1, 2)) > self.mask_occluded_thr
valid_inds = bboxes_inds | masks_inds
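The keep rule as a standalone function with a worked example. `keep_mask` is a hypothetical name, and it takes per-object visible pixel counts instead of full masks to keep the numbers readable.

```python
import numpy as np

def keep_mask(orig_boxes, upd_boxes, visible_pixels,
              bbox_occluded_thr=10, mask_occluded_thr=300):
    """Keep an object if its box barely moved OR enough of its mask is still visible."""
    l1 = np.abs(upd_boxes - orig_boxes)                   # per-coordinate shift, (N, 4)
    bbox_ok = (l1 <= bbox_occluded_thr).all(axis=-1)      # every coordinate within threshold
    mask_ok = visible_pixels > mask_occluded_thr          # enough visible pixels remain
    return bbox_ok | mask_ok

orig = np.array([[0, 0, 50, 50], [0, 0, 100, 100], [200, 200, 220, 220]], dtype=float)
upd = np.array([[0, 0, 50, 50], [0, 60, 100, 100], [0, 0, 0, 0]], dtype=float)
visible = np.array([100, 4000, 0])
keep = keep_mask(orig, upd, visible)
```

The first object is kept by the box test, the second (box shrank by 60 px) by the pixel test, and the third (fully covered) fails both. One consequence that follows from the code rather than from any documentation: a fully occluded object gets an all-zero box, so if its original box lies entirely within `bbox_occluded_thr` pixels of the origin (e.g. [2, 2, 8, 8]) it still passes the box test and is kept with an empty mask.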
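Finally, the image is updated with \(I_{1}\times \alpha+I_{2}\times (1-\alpha)\), where in this code \(I_{1}\) is `src_img`, \(I_{2}\) is `dst_img` and \(\alpha\) is `composed_mask` (a hard binary mask; this implementation applies no Gaussian smoothing). The boxes, masks, labels and ignore flags are updated accordingly: the surviving destination annotations come first, followed by all pasted source annotations.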
# Paste source objects to destination image directly
img = dst_img * (1 - composed_mask[..., np.newaxis]
                 ) + src_img * composed_mask[..., np.newaxis]
bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
labels = np.concatenate([dst_labels[valid_inds], src_labels])
masks = np.concatenate(
    [updated_dst_masks.masks[valid_inds], src_masks.masks])
ignore_flags = np.concatenate(
    [dst_ignore_flags[valid_inds], src_ignore_flags])
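Putting the pieces together, `_copy_paste` can be condensed into plain NumPy operating on dicts of arrays. This is a hypothetical re-implementation for illustration (`copy_paste`, `masks_to_bboxes` and the dict keys are my own), not mmdet's API; boxes are plain (N, 4) arrays instead of mmdet box objects, and ignore flags are omitted.

```python
import numpy as np

def masks_to_bboxes(masks):
    """Tight [x1, y1, x2, y2] boxes (x2/y2 exclusive); empty masks give zeros."""
    boxes = np.zeros((len(masks), 4), dtype=np.float32)
    for i, m in enumerate(masks):
        ys, xs = np.nonzero(m)
        if len(xs):
            boxes[i] = [xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]
    return boxes

def copy_paste(dst, src, bbox_occluded_thr=10, mask_occluded_thr=300):
    """dst/src: dicts with 'img' (H, W, C), 'masks' (N, H, W), 'labels' (N,), 'boxes' (N, 4)."""
    if len(src["masks"]) == 0:
        return dst
    alpha = np.any(src["masks"], axis=0)                  # union of pasted objects, (H, W)
    upd_masks = np.where(alpha, 0, dst["masks"])          # occlude destination objects
    upd_boxes = masks_to_bboxes(upd_masks)
    bbox_ok = (np.abs(upd_boxes - dst["boxes"]) <= bbox_occluded_thr).all(axis=-1)
    mask_ok = upd_masks.sum(axis=(1, 2)) > mask_occluded_thr
    valid = bbox_ok | mask_ok
    a = alpha[..., np.newaxis].astype(dst["img"].dtype)   # keep the image dtype
    return {
        "img": dst["img"] * (1 - a) + src["img"] * a,     # I1 * alpha + I2 * (1 - alpha)
        "masks": np.concatenate([upd_masks[valid], src["masks"]]),
        "labels": np.concatenate([dst["labels"][valid], src["labels"]]),
        "boxes": np.concatenate([upd_boxes[valid], src["boxes"]]),
    }
```

One design difference: here \(\alpha\) is cast to the image dtype before blending, so a uint8 image stays uint8. In the quoted code `composed_mask` comes from `np.where(..., 1, 0)` and is an integer array, so NumPy's type promotion turns the blended image into a wider integer type (typically int64).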