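Background:
Sometimes sequence data has to be processed with one and the same torchvision.transforms operation. Video is a typical case: every frame in a group of video frames needs the same data augmentation.
Problem description and cause
Most of the ready-made random transform classes in torchvision follow the same pattern in their source code, for example: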
    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)
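The code above comes from the RandomRotation class. Every call to forward runs self.get_params() once, so each image gets a different rotation angle. The other random transforms behave the same way.
Solution process:
Since the transforms applied to one batch of images must be identical, the code can easily be modified, for example: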
class RandomRotation_once(transforms.RandomRotation):
    def __init__(self, degrees, interpolation=F.InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__(degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
        self.angle = self.get_params(degrees)

    def forward(self, img):
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        return F.rotate(img, self.angle, self.interpolation, self.expand, self.center, fill)
class ColorJitter_once(transforms.ColorJitter):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0):
        super().__init__(brightness, contrast, saturation, hue)
        # Sample the operation order and all factors once, at construction time.
        self.fn_idx, self.brightness_factor, self.contrast_factor, \
            self.saturation_factor, self.hue_factor = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        # Draw the "apply or not" decision once as well.
        self.probability = torch.rand(1)
        self.p = p  # with the default p=0 the jitter is never applied

    def forward(self, img):
        if self.probability < self.p:
            for fn_id in self.fn_idx:
                if fn_id == 0 and self.brightness_factor is not None:
                    img = F.adjust_brightness(img, self.brightness_factor)
                elif fn_id == 1 and self.contrast_factor is not None:
                    img = F.adjust_contrast(img, self.contrast_factor)
                elif fn_id == 2 and self.saturation_factor is not None:
                    img = F.adjust_saturation(img, self.saturation_factor)
                elif fn_id == 3 and self.hue_factor is not None:
                    img = F.adjust_hue(img, self.hue_factor)
        return img
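The key point is to put:
import torchvision.transforms.functional as F
at the top of the code. The classes above also rely on torch, Tensor (from torch) and transforms (from torchvision) being imported. For the reason, see:
Conclusion and key points:
This approach subclasses the original class and moves the sampling of angle into __init__, so every forward pass of the same instance uses the same parameters. It even makes it possible to add an application probability to ColorJitter.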