pytorch一致数据增强—独用增强

前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose 类和 to_multi 包装函数。不过 [1] 没考虑各图用不同 augmentation 的情况,如:

  1. ColorJitter 只对 image 做,而不对 label 做;
  2. image 的 resize interpolation 可任选,但 label 只能用 nearest

本篇更新写法,支持各图同用、独用 augmentation。

Code

  • 对比 [1],主要改变是改写 MultiCompose 类,并将 to_multi 吸收入内。
  • MultiCompose 的用法还是和 torchvision.transforms.Compose 几乎一致,不过支持独用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
    """Deprecated: the standalone `to_multi` wrapper has been merged into
    `MultiCompose` and is intentionally unusable."""
    raise NotImplementedError


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple
    inputs (e.g. image & segmentation label) and applies each augmentation
    step to all of them under the same random state (e.g. the same random
    rotation angle). Useful for segmentation tasks.

    Fix vs. original: the body mixed tabs and spaces (a TabError in
    Python 3); indentation is normalized to 4 spaces throughout.
    """

    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    MIN_SEED = 0  # - 0x8000_0000_0000_0000
    MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)

    def __init__(self, transforms):
        """
        transforms: list/tuple whose elements are each either
            - a single transform object/function, shared by all inputs; or
            - a nested list/tuple giving one transform per input
              ("per-input" mode); use `None` for "no transform" there.
        """
        no_op = lambda x: x  # i.e. identity function
        self.transforms = []
        for t in transforms:
            if isinstance(t, (tuple, list)):
                # convert `None` to `no_op` for convenience
                self.transforms.append([no_op if _t is None else _t for _t in t])
            else:
                self.transforms.append(t)

    def __call__(self, *images):
        """Apply all transform steps to all inputs; a single input is
        returned unwrapped, multiple inputs as a sequence."""
        for t in self.transforms:
            if isinstance(t, (tuple, list)):  # per-input transforms
                # `<=` allows redundant transforms
                assert len(images) <= len(t), \
                    f"#inputs: {len(images)} v.s. #transforms: {len(t)}"
            else:  # one transform shared by all inputs
                t = [t] * len(images)

            _aug_images = []
            # one fresh seed per step, reused for every input of this step
            _seed = random.randint(self.MIN_SEED, self.MAX_SEED)
            for _im, _t in zip(images, t):
                # NOTE(review): `seed_everything` is defined externally
                # (see ref [1] of the article) — not visible in this file.
                seed_everything(_seed)
                _aug_images.append(_t(_im))

            images = _aug_images

        if len(images) == 1:
            images = images[0]
        return images

Usage & Test

例程沿用 [1],但改一下 augmentation:

train_trans = MultiCompose([
    # interpolation: image uses bilinear, label uses nearest
    (ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # per-input
    transforms.RandomAffine(30, (0.1, 0.1)), # shared: passing one transform is enough
    transforms.RandomHorizontalFlip(), # shared
    # apply ColorJitter on image only, not on label (hence `None`)
    [transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # per-input
])
  • 效果:

test-dataset-1.png

Supporting Multiple Input Styles

(2024.3.13)前文的 MultiCompose 只支持顺序,如果有很多个输入,则用 dict 通过 key 分辨各输入更方便。故改写之以支持顺序输入和 dict 输入两种模式。另支持指定 seed 以保证复现。

class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple inputs
    and ensures the same random seed is applied on each of these inputs at each transforms.
    This can be useful when simultaneously transforming images & segmentation masks.

    Fixes vs. original: continuation lines mixed tabs and spaces (a TabError
    in Python 3); the sequential-mode assert reported the wrong length; the
    duplicated seed-selection logic is factored into `_next_seed`.

    Usage:
        ```python
        ## 1. compatible with single input (just like torchvision.transforms.Compose)
        trfm = MultiCompose([
            transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
            transforms.RandomAffine(30, (0.1, 0.1)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1, 0.2, 0.3, 0.4)
        ])
        aug_images = trfm(images)

        ## 2. sequential style
        seq_trfm = MultiCompose([
            # interpolation: image uses `bilinear`, label uses `nearest`
            [transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
             transforms.Resize((224, 256), transforms.InterpolationMode.NEAREST)],
            transforms.RandomAffine(30, (0.1, 0.1)),
            transforms.RandomHorizontalFlip(),
            # apply `ColorJitter` on image but not on label (thus `None`)
            (transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None),
        ])
        # apply augmentations on both `images` and `seg_labels`
        aug_images, aug_seg_labels = seq_trfm(images, seg_labels)

        ## 3. dict style
        dict_trfm = MultiCompose([
            # interpolation: image uses `bilinear`, label uses `nearest`
            {"image": transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
             "label": transforms.Resize((224, 256), transforms.InterpolationMode.NEAREST)},
            transforms.RandomAffine(30, (0.1, 0.1)),
            transforms.RandomHorizontalFlip(),
            # apply `ColorJitter` on image but not on label (lack here)
            {"image": transforms.ColorJitter(0.1, 0.2, 0.3, 0.4)},
        ])
        # apply augmentations on both `images` and `seg_labels`
        res = dict_trfm({"image": images, "label": seg_labels})
        aug_images = res["image"]
        aug_seg_labels = res["label"]
        ```
    """

    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    MIN_SEED = 0  # - 0x8000_0000_0000_0000
    MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)

    def __init__(self, transforms, seed=None):
        """
        transforms: list/tuple of:
            - transform object (for all inputs)
            - embedded list/tuple/dict of transform objects (for each input)
        seed: int, always use this seed if provided (deterministic for reproducibility)
        """
        self.transforms = transforms
        self.seed = seed

    def append(self, t):
        """Append one more transform step."""
        self.transforms.append(t)

    def extend(self, ts):
        """Append several transform steps at once."""
        assert isinstance(ts, (tuple, list))
        for t in ts:
            self.append(t)

    def _next_seed(self):
        """Fixed seed if one was provided, else a fresh random seed."""
        if self.seed is not None:
            return self.seed
        return random.randint(MultiCompose.MIN_SEED, MultiCompose.MAX_SEED)

    def call_sequential(self, *images):
        """Positional mode: transforms matched to inputs by order."""
        for t in self.transforms:
            if isinstance(t, (tuple, list)):
                # `<=` allows redundant transforms
                assert len(images) <= len(t), \
                    f"#inputs: {len(images)} v.s. #transforms: {len(t)}"
            else:
                t = [t] * len(images)

            _aug_images = []
            _seed = self._next_seed()
            for _im, _t in zip(images, t):
                # NOTE(review): `seed_everything` is defined externally
                # (see ref [1] of the article) — not visible in this file.
                seed_everything(_seed)
                _aug_images.append(_im if _t is None else _t(_im))

            images = _aug_images

        if len(images) == 1:
            images = images[0]
        return images

    def call_dict(self, images):
        """Dict mode: transforms matched to inputs by key; missing or
        `None` entries leave that input unchanged at this step."""
        for t in self.transforms:
            if not isinstance(t, dict):
                t = {k: t for k in images}

            _aug_images = {}
            _seed = self._next_seed()
            for k in images:
                seed_everything(_seed)
                _aug_images[k] = t[k](images[k]) if k in t and t[k] is not None else images[k]

            images = _aug_images

        return images

    def __call__(self, *images):
        """Dispatch: a single dict argument selects dict mode, anything
        else sequential mode. No arguments is a no-op (returns `()`)."""
        if not images:  # guard: avoid IndexError on empty call
            return images
        if isinstance(images[0], dict):
            assert len(images) == 1
            return self.call_dict(images[0])
        else:
            return self.call_sequential(*images)

示例用法:

print("1. 单个输入 (兼容 torchvision.transforms.Compose)")
trfm = MultiCompose([
    transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
    transforms.RandomAffine(30, (0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1, 0.2, 0.3, 0.4)
])
aug_images = trfm(images)

print("2. 顺序输入")
seq_trfm = MultiCompose([
    # interpolation: image uses `bilinear`, label uses `nearest`
    [transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
     transforms.Resize((224, 256), transforms.InterpolationMode.NEAREST)],
    transforms.RandomAffine(30, (0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    # apply `ColorJitter` on image but not on label (thus `None`)
    (transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None),
])
aug_images, aug_seg_labels = seq_trfm(images, seg_labels)

print("3. dict 输入")
dict_trfm = MultiCompose([
    # interpolation: image uses `bilinear`, label uses `nearest`
    {"image": transforms.Resize((224, 256), transforms.InterpolationMode.BILINEAR),
     "label": transforms.Resize((224, 256), transforms.InterpolationMode.NEAREST)},
    transforms.RandomAffine(30, (0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    # apply `ColorJitter` on image but not on label (lack here)
    {"image": transforms.ColorJitter(0.1, 0.2, 0.3, 0.4)},
])
# the result is also a dict, with the same keys
res = dict_trfm({"image": images, "label": seg_labels})
aug_images = res["image"]
aug_seg_labels = res["label"]

References

  1. pytorch一致数据增强
  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值