pytorch torchvision 批量图像 统一变换

问题背景:

有的时候,处理序列数据需要使用 同一种 torchvison.transforms, 比如视频,这一组视频帧的数据增强方式需要一致。


问题描述及问题成因

在使用torchvision的时候大多数现成的 变换类 的 源码如下:

@staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

以上所示为 RnadomRotate 类 的代码,可见每次调用 forward 时都会执行一次self.get_params()这就导致了每张图像的 angle 旋转角度都不一致,其他变换同理


解决过程:

那既然要 使 同一批图像的变换是一致的,则可方便的修改代码为例:

class RandomRotation_once(transforms.RandomRotation):
    def __init__(self, degrees, interpolation=F.InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__(degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
        self.angle = self.get_params(degrees)
    def forward(self, img):
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        return F.rotate(img, self.angle, self.interpolation, self.expand, self.center, fill)

class ColorJitter_once(transforms.ColorJitter):
    def __init__(self,brightness=0,contrast=0,saturation=0,hue=0,p=0):
        super().__init__(brightness,contrast,saturation,hue)
        self.fn_idx, self.brightness_factor, self.contrast_factor, \
            self.saturation_factor, self.hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)
        self.propability = torch.rand(1)
        self.p = p
    def forward(self, img):
        if self.propability < self.p:
            for fn_id in self.fn_idx:
                if fn_id == 0 and self.brightness_factor is not None:
                    img = F.adjust_brightness(img, self.brightness_factor)
                elif fn_id == 1 and self.contrast_factor is not None:
                    img = F.adjust_contrast(img, self.contrast_factor)
                elif fn_id == 2 and self.saturation_factor is not None:
                    img = F.adjust_saturation(img, self.saturation_factor)
                elif fn_id == 3 and self.hue_factor is not None:
                    img = F.adjust_hue(img, self.hue_factor)

        return img

 重点是将:

import torchvision.transforms.functional as F

添加在代码前。原因见:

解决 :Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant-CSDN博客


解决方案结论及重点:

该方案继承原来的类别,并将 angle 的设定放在 init 里,使其 同一实例的前向过程的参数是一致的。而这样甚至可以在 进行 colorjitter的时候,把进行变换的概率进行添加。


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值