Hand-written torchvision transforms

[1] uses several image augmentations from PyTorch 1.7.1 and kornia 0.5.10. The CUDA version on my machine is too old for those, so I rewrote them with PyTorch 1.4.0 and OpenCV.

original

  • The original Augmentation in [1] and how it is called
  • kornia's API operates on a whole batch at once
import torch as T
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import kornia.augmentation as Kg
# sz, data_dir, batch_size, device and args are defined elsewhere in the script of [1]

Augmentation = nn.Sequential(
    Kg.RandomResizedCrop(size=(sz, sz)),
    Kg.RandomHorizontalFlip(p=0.5),
    Kg.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
    Kg.RandomGrayscale(p=0.2),
    Kg.RandomGaussianBlur((int(0.1 * sz), int(0.1 * sz)), (0.1, 2.0), p=0.5)
)

transform = transforms.ToTensor()

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=args.if_download, transform=transform)
trainloader = T.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

for i, data in enumerate(trainloader, 0):
    inputs = data[0].to(device)
    Ia = Augmentation(inputs)
    Ib = Augmentation(inputs)
    # ... rest of the training step ...
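Because the batched call above is replaced by a per-image transform below, it helps to see why the two are interchangeable: kornia's augmentation modules take a `same_on_batch` argument that defaults to `False`, so each sample in the batch gets its own random parameters. A minimal NumPy sketch of that idea for a horizontal flip, using a hypothetical helper `batch_random_hflip` (an illustration, not kornia's implementation):

```python
import numpy as np


def batch_random_hflip(batch, p=0.5, rng=None):
    """Flip each image of an (N, C, H, W) batch independently with probability p.

    Hypothetical helper: shows per-sample randomness, not kornia's actual code.
    """
    rng = np.random.default_rng() if rng is None else rng
    flip = rng.random(batch.shape[0]) < p          # one coin toss per sample
    out = batch.copy()
    out[flip] = out[flip][..., ::-1]                # reverse the width axis of chosen samples
    return out, flip
```

Applying a per-image `RandomHorizontalFlip(p=0.5)` to every sample of the batch gives the same distribution of outputs, which is what lets the augmentation move into the dataset transform.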

rewrite

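  • Rewritten with torchvision's built-in transforms plus OpenCV's Gaussian blur [3,4]
  • Follows the style of MoCo [5] and its TwoCropsTransform (MoCo also defines its own GaussianBlur, which its MoCo v2 `--aug-plus` option uses; it is not used here)
  • torchvision transforms operate on a single image, so the augmentation goes into the dataset's transform, and TwoCropsTransform produces the two views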
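Two details matter when porting the Gaussian blur to OpenCV [3,4]. First, cv2.GaussianBlur only accepts positive odd kernel sizes, and int(0.1 * sz) is odd for sz = 32 (CIFAR-10) but not, for example, for sz = 224. Second, kornia pads with torch's 'reflect' mode, which mirrors without repeating the edge pixel; that is numpy's 'reflect' and OpenCV's BORDER_REFLECT_101 (also OpenCV's default border), whereas BORDER_REFLECT repeats the edge like numpy's 'symmetric'. A small NumPy sketch; `odd_kernel_size` is a hypothetical helper and rounding up is just one possible choice:

```python
import numpy as np


def odd_kernel_size(sz, ratio=0.1):
    """Hypothetical helper: int(ratio * sz), bumped to the next odd number,
    because cv2.GaussianBlur requires positive odd kernel sizes."""
    k = max(1, int(ratio * sz))
    return k if k % 2 == 1 else k + 1


row = np.array([1, 2, 3, 4])
# torch 'reflect' padding (and cv2.BORDER_REFLECT_101) does not repeat the edge pixel
reflect_101 = np.pad(row, 2, mode="reflect")     # [3 2 1 2 3 4 3 2]
# cv2.BORDER_REFLECT repeats the edge pixel, like numpy 'symmetric'
reflect = np.pad(row, 2, mode="symmetric")       # [2 1 1 2 3 4 4 3]
```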
import random
import numpy as np
import cv2
from PIL import Image
import torch as T
import torchvision
import torchvision.transforms as transforms

class TwoCropsTransform:
    """https://github.com/facebookresearch/moco/blob/main/moco/loader.py#L6"""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class RandomGaussianBlur:
    """random Gaussian blur in opencv, mimicking kornia
    ref:
    - https://kornia.readthedocs.io/en/0.5.10/augmentation.module.html?highlight=RandomGaussianBlur#kornia.augmentation.RandomGaussianBlur
    - https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html#gaabe8c836e97159a9193fb0b11ac52cf1
    - https://docs.opencv.org/4.x/d2/de8/group__core__array.html#ga209f2f4869e304c82d07739337eae7c5
    """

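    def __init__(self, kernel_size, sigma, border_type='reflect', p=0.5):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            assert isinstance(kernel_size, (tuple, list)) and (len(kernel_size) == 2)
        # cv2.GaussianBlur requires positive odd kernel sizes
        assert all(k > 0 and k % 2 == 1 for k in kernel_size)
        self.kernel_size = tuple(kernel_size)

        # sigma is used as a fixed (sigmaX, sigmaY) pair, not sampled from a range
        assert isinstance(sigma, (tuple, list)) and (len(sigma) == 2)
        self.sigma = sigma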

        # kornia 0.5.10 also accepts 'circular', which cv2.GaussianBlur does not support
        assert border_type in ["constant", "reflect", "replicate"]
        if "reflect" == border_type:
            # torch's 'reflect' padding (used by kornia) does not repeat the edge pixel
            self.border_type = cv2.BORDER_REFLECT_101
        elif "constant" == border_type:
            self.border_type = cv2.BORDER_CONSTANT
        elif "replicate" == border_type:
            self.border_type = cv2.BORDER_REPLICATE

        self.p = p

    def __call__(self, x):
        if random.random() >= self.p:
            return x

        x = np.array(x)  # PIL image -> H x W x C uint8 array
        x = cv2.GaussianBlur(x, self.kernel_size,
            sigmaX=self.sigma[0], sigmaY=self.sigma[1], borderType=self.border_type)
        x = Image.fromarray(x)  # back to PIL for the remaining transforms
        return x


Augmentation = transforms.Compose([
    transforms.RandomResizedCrop(size=(sz, sz)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    RandomGaussianBlur((int(0.1 * sz), int(0.1 * sz)), (0.1, 2.0), p=0.5),
    transforms.ToTensor(),  # convert to a tensor here, as the last step
])

trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=args.if_download,
    transform=TwoCropsTransform(Augmentation))  # replace the dataset's transform
trainloader = T.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

for i, (data, _) in enumerate(trainloader, 0):
    # data is [batch of first views, batch of second views]
    Ia = data[0].to(device)
    Ib = data[1].to(device)
    # ... rest of the training step ...
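Why `data[0]` and `data[1]` are whole batches: each sample is `([q, k], label)`, and the DataLoader's default collate function transposes that structure into `([Q, K], labels)`, stacking the first views together and the second views together. A simplified NumPy imitation of that behaviour; `toy_collate` and `noisy` are hypothetical stand-ins for torch's default collate and a real augmentation:

```python
import random

import numpy as np


class TwoCropsTransform:
    """Same idea as MoCo's TwoCropsTransform: two independent draws of the transform."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return [self.base_transform(x), self.base_transform(x)]


def toy_collate(samples):
    """Hypothetical, simplified stand-in for torch's default collate:
    a list of ([q, k], label) samples becomes ([Q, K], labels)."""
    views, labels = zip(*samples)
    q = np.stack([v[0] for v in views])  # all first views -> one batch
    k = np.stack([v[1] for v in views])  # all second views -> one batch
    return [q, k], np.array(labels)


def noisy(x):
    # toy random "augmentation": add a random offset
    return x + random.random()
```

For five 3x4x4 images, `data, labels = toy_collate(samples)` gives `data[0].shape == (5, 3, 4, 4)`, and the two views differ because the transform is drawn twice per image.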

validation

With this rewrite of [1], I could roughly reproduce its results on CIFAR-10 with 32-bit codes (its default showcase), so the rewrite appears to work.

References

  1. youngkyunJang/SPQ
  2. kornia | RandomGaussianBlur
  3. opencv | GaussianBlur
  4. opencv | BorderTypes
  5. facebookresearch/moco
  6. torchvision | RandomHorizontalFlip (reference for the random-apply pattern, i.e. the p parameter)