引 言 : \color{#FF3030}{引言:} 引言:
数据增广在深度学习领域已经被证明是个极为有效的提高模型效果和泛化能力的措施。一般默认的数据增广方式有:
1.图片水平、竖直翻转
2.图片旋转(如90、180、270度旋转)
3.图片缩放
4.图片裁剪(如中心裁剪,随机裁剪等方式)
5.图片加噪声(如高斯)
…
最新提出的数据增广方法有,如上图:
1.CutOut,随机在图片上找一个矩形区域丢掉
2.MixUp,将两张图片按照一定比率进行像素融合,融合后的图带有两张图片的标签
3.CutMix,与CutOut类似,不同的地方是将这个矩形区域填充为另一张图片的像素,本人亲测,稳定有涨点!
4.本文中AugMix,在下面详细说
创 新 点 : \color{#FF3030}{创新点:} 创新点:
操作来说如上图,将一张原图进行变换、旋转、多色调3个分支的并行操作,然后按照一定比例融合,最后与原图按照一定比率融合。
上图是算法流程,大概流程就是对原图进行k个数据增广操作,然后使用k个权重进行融合得到aug图像。最终与原图进行一定比率融合得到最终augmix图像。Jensen-Shannon Divergence Consistency Loss这个损失函数是用来计算augmix后图像与原图的JS散度,需要保证augmix图像与原图的相似性。
代 码 : \color{#FF3030}{代码:} 代码:
1.AugMix:
def aug(image, preprocess):
"""Perform AugMix augmentations and compute mixture.
Args:
image: PIL.Image input image
preprocess: Preprocessing function which should return a torch tensor.
Returns:
mixed: Augmented and mixed image.
"""
ws = np.float32(np.random.dirichlet([1] * args.mixture_width))
m = np.float32(np.random.beta(1, 1))
mix = torch.zeros_like(preprocess(image))
for i in range(args.mixture_width):
image_aug = image.copy()
depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint( 1, 4)
for _ in range(depth):
op = np.random.choice(augmentations.augmentations)
image_aug = op(image_aug, args.aug_severity)
# Preprocessing commutes since all coefficients are convex
# k个增广加权融合
mix += ws[i] * preprocess(image_aug)
# 与原图加权融合
mixed = (1 - m) * preprocess(image) + m * mix
return mixed
2.AugMix_DataSet:
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix augmentation."""
def __init__(self, dataset, preprocess, no_jsd=False):
self.dataset = dataset
self.preprocess = preprocess
self.no_jsd = no_jsd
def __getitem__(self, i):
x, y = self.dataset[i]
# 如果不使用js损失,直接返回augmix图像和标签
if self.no_jsd:
return aug(x, self.preprocess), y
# 如果使用js损失,则返回原图,2次augmix图像组成的tuple和标签
else:
im_tuple = (self.preprocess(x), aug(x, self.preprocess),
aug(x, self.preprocess))
return im_tuple, y
def __len__(self):
return len(self.dataset)
3.train:
def train(net, train_loader, optimizer, scheduler):
"""Train for one epoch."""
net.train()
loss_ema = 0.
for i, (images, targets) in enumerate(train_loader):
optimizer.zero_grad()
# 不使用js损失
if args.no_jsd:
images = images.cuda()
targets = targets.cuda()
logits = net(images)
loss = F.cross_entropy(logits, targets)
else:
images_all = torch.cat(images, 0).cuda()
targets = targets.cuda()
logits_all = net(images_all)
logits_clean, logits_aug1, logits_aug2 = torch.split(
logits_all, images[0].size(0))
# 使用js损失则包含两个损失函数
# Cross-entropy is only computed on clean images
loss = F.cross_entropy(logits_clean, targets)
p_clean, p_aug1, p_aug2 = F.softmax(
logits_clean, dim=1), F.softmax(
logits_aug1, dim=1), F.softmax(
logits_aug2, dim=1)
# Clamp mixture distribution to avoid exploding KL divergence
p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
loss.backward()
optimizer.step()
scheduler.step()
loss_ema = loss_ema * 0.1 + float(loss) * 0.9
if i % args.print_freq == 0:
print('Train Loss {:.3f}'.format(loss_ema))
return loss_ema