代码地址:https://github.com/POSTECH-CVLab/point-transformer/blob/10d43ab5210fc93ffa15886f2a4c6460cc308780/tool/train.py#L165
from util import transform as t
train_transform = t.Compose([ t.RandomScale([0.9, 1.1]),
t.ChromaticAutoContrast(),
t.ChromaticTranslation(),
t.ChromaticJitter(),
t.HueSaturationTranslation()
])
数据增强的方式是随机尺寸放缩、自动对比度、色彩平移、调节对比度饱和度和随机噪声。
写一个Compose函数来进行数据增强,transform是一个list,每个里面都是一个类的instance。
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, coord, feat, label):
for t in self.transforms:
coord, feat, label = t(coord, feat, label)
return coord, feat, label
依次定义不同的数据增强方式
class RandomScale(object):
def __init__(self, scale=[0.9, 1.1], anisotropic=False):
self.scale = scale
self.anisotropic = anisotropic
def __call__(self, coord, feat, label):
scale = np.random.uniform(self.scale[0], self.scale[1], 3 if self.anisotropic else 1)
coord *= scale
return coord, feat, label
调用:
if transform:
coord, feat, label = transform(coord, feat, label)