[1] 中用到 PyTorch 1.7.1 和 kornia 0.5.10 的一些图像增强,而我的机器的 CUDA 版本不够新,只能用 PyTorch 1.4.0,所以要用 torchvision 和 OpenCV 重写这部分增强。
original
- [1] 中原本的 Augmentation 及调用
- kornia 的 API 支持对一个 batch 操作
# import torch as T
# import torch.nn as nn
# import torchvision
# import torchvision.transforms as transforms
# import kornia.augmentation as Kg
# kornia-based augmentation pipeline as used in [1].
# It is applied to a whole batch after the batch has been moved to `device`.
blur_ksize = int(0.1 * sz)
augment_ops = [
    Kg.RandomResizedCrop(size=(sz, sz)),
    Kg.RandomHorizontalFlip(p=0.5),
    Kg.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
    Kg.RandomGrayscale(p=0.2),
    Kg.RandomGaussianBlur((blur_ksize, blur_ksize), (0.1, 2.0), p=0.5),
]
Augmentation = nn.Sequential(*augment_ops)

# The dataset itself only converts images to tensors; augmentation happens
# later, inside the training loop.
transform = transforms.ToTensor()
trainset = torchvision.datasets.CIFAR10(
    root=data_dir,
    train=True,
    download=args.if_download,
    transform=transform,
)
trainloader = T.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
)

for i, data in enumerate(trainloader):
    inputs = data[0].to(device)
    # Two separate passes through the augmentation give the two views.
    Ia, Ib = Augmentation(inputs), Augmentation(inputs)
    # ... other stuff ...
rewrite
- 用 torchvision 原有的 transforms 和 opencv 的 Gaussian blur[3,4] 重写
- 参考 MoCo[5] 的写法和 TwoCropsTransform(MoCo 也有写自己的 GaussianBlur,但好像没用上)
- torchvision 的 transform 对单张 image 操作,所以放在 dataset 的 transform 那,用 `TwoCropsTransform` 对同一张图产生两张增强后的图像
# import random
# import numpy as np
# import cv2
# from PIL import Image
# import torch as T
# import torchvision
# import torchvision.transforms as transforms
class TwoCropsTransform:
    """Apply one transform twice to the same input and return both results.

    Adapted from MoCo:
    https://github.com/facebookresearch/moco/blob/main/moco/loader.py#L6
    """

    def __init__(self, base_transform):
        # Callable that is invoked once per requested view.
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[q, k]``: two successive applications of the transform to `x`."""
        return [self.base_transform(x) for _ in range(2)]
class RandomGaussianBlur:
    """Random Gaussian blur implemented with OpenCV, mimicking kornia.

    With probability `p` the input is converted to a NumPy array, blurred with
    `cv2.GaussianBlur` and wrapped back into a PIL image; otherwise the input
    is returned untouched.

    Args:
        kernel_size: an int (used for both dimensions) or a 2-element
            tuple/list; forwarded to `cv2.GaussianBlur` as `ksize`.
        sigma: 2-element tuple/list, forwarded as `(sigmaX, sigmaY)`.
        border_type: one of "constant", "reflect", "replicate".
        p: probability of applying the blur.

    Raises:
        AssertionError: if `kernel_size` or `sigma` does not have the expected
            type/length, or `border_type` is not supported.

    ref:
    - https://kornia.readthedocs.io/en/0.5.10/augmentation.module.html?highlight=RandomGaussianBlur#kornia.augmentation.RandomGaussianBlur
    - https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html#gaabe8c836e97159a9193fb0b11ac52cf1
    - https://docs.opencv.org/4.x/d2/de8/group__core__array.html#ga209f2f4869e304c82d07739337eae7c5
    """

    def __init__(self, kernel_size, sigma, border_type='reflect', p=0.5):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            assert isinstance(kernel_size, (tuple, list)) and (len(kernel_size) == 2), \
                "kernel_size must be an int or a 2-element tuple/list"
        self.kernel_size = kernel_size

        # Validate the length of `sigma` itself (not of `kernel_size`), so a
        # malformed sigma is rejected here instead of failing inside __call__.
        assert isinstance(sigma, (tuple, list)) and (len(sigma) == 2), \
            "sigma must be a 2-element tuple/list"
        self.sigma = sigma

        # cv2 does NOT support `circular` like kornia 0.5.10
        assert border_type in ("constant", "reflect", "replicate")
        # NOTE(review): kornia pads through torch's "reflect" mode, which looks
        # closer to cv2.BORDER_REFLECT_101 than to cv2.BORDER_REFLECT; confirm
        # before relying on exact parity at the image borders.
        self.border_type = {
            "constant": cv2.BORDER_CONSTANT,
            "reflect": cv2.BORDER_REFLECT,
            "replicate": cv2.BORDER_REPLICATE,
        }[border_type]

        self.p = p

    def __call__(self, x):
        """Blur `x` with probability `self.p`; otherwise return `x` unchanged.

        `x` must be convertible with `np.array` (presumably a PIL image, given
        that the blurred result is returned via `Image.fromarray`).
        """
        if random.random() >= self.p:
            return x
        x = np.array(x)
        x = cv2.GaussianBlur(
            x,
            self.kernel_size,
            sigmaX=self.sigma[0],
            sigmaY=self.sigma[1],
            borderType=self.border_type,
        )
        return Image.fromarray(x)
# Per-image augmentation built from torchvision transforms plus the
# OpenCV-backed RandomGaussianBlur.
Augmentation = transforms.Compose([
    transforms.RandomResizedCrop(size=(sz, sz)),
    transforms.RandomHorizontalFlip(p=0.5),
    # ColorJitter is wrapped in RandomApply so that it fires with p=0.8.
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    RandomGaussianBlur((int(0.1 * sz), int(0.1 * sz)), (0.1, 2.0), p=0.5),
    transforms.ToTensor(),  # convert to tensor here, as the last step
])

# The dataset transform is replaced by TwoCropsTransform, so every sample
# carries two augmented views of the same image.
trainset = torchvision.datasets.CIFAR10(
    root=data_dir,
    train=True,
    download=args.if_download,
    transform=TwoCropsTransform(Augmentation),
)
trainloader = T.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
)

for i, (data, _) in enumerate(trainloader, 0):
    # `data` holds the two views produced by TwoCropsTransform.
    Ia = data[0].to(device)
    Ib = data[1].to(device)
    # ... other stuff ...
validation
对 [1] 改写之后,可以基本复现其在 cifar-10 32 bits 上的结果(即其默认的 showcase),所以这个改写应该是能用的。