[数据增广]--AUGMIX:A SIMPLE DATA PROCESSING METHOD TO IMPROVE ROBUSTNESS AND UNCERTA

在这里插入图片描述

引 言 : \color{#FF3030}{引言:}

在这里插入图片描述
数据增广在深度学习领域已经被证明是个极为有效的提高模型效果和泛化能力的措施。一般默认的数据增广方式有:
1.图片水平、竖直翻转
2.图片旋转(如90、180、270度旋转)
3.图片缩放
4.图片裁剪(如中心裁剪,随机裁剪等方式)
5.图片加噪声(如高斯)

最新提出的数据增广方法有,如上图:
1.CutOut,随机在图片上找一个矩形区域丢掉
2.MixUp,将两张图片按照一定比率进行像素融合,融合后的图带有两张图片的标签
3.CutMix,与CutOut类似,不同的地方是将这个矩形区域填充为另一张图片的像素,本人亲测,稳定有涨点!
4.本文中AugMix,在下面详细说

创 新 点 : \color{#FF3030}{创新点:}

在这里插入图片描述
操作来说如上图,将一张原图进行变换、旋转、多色调3个分支的并行操作,然后按照一定比例融合,最后与原图按照一定比率融合。
在这里插入图片描述
上图是算法流程,大概流程就是对原图进行k个数据增广操作,然后使用k个权重进行融合得到aug图像。最终与原图进行一定比率融合得到最终augmix图像。Jensen-Shannon Divergence Consistency Loss这个损失函数是用来计算augmix后图像与原图的JS散度,需要保证augmix图像与原图的相似性。

代 码 : \color{#FF3030}{代码:}

1.AugMix:

def aug(image, preprocess): 
   """Perform AugMix augmentations and compute mixture.   
   Args: 
    image: PIL.Image input image 
    preprocess: Preprocessing function which should return a torch tensor.   
   Returns: 
     mixed: Augmented and mixed image. 
   """ 
   ws = np.float32(np.random.dirichlet([1] * args.mixture_width)) 
   m = np.float32(np.random.beta(1, 1)) 

   mix = torch.zeros_like(preprocess(image)) 
   for i in range(args.mixture_width): 
     image_aug = image.copy() 
     depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint( 1, 4) 
     for _ in range(depth): 
       op = np.random.choice(augmentations.augmentations) 
       image_aug = op(image_aug, args.aug_severity) 
     # Preprocessing commutes since all coefficients are convex 
     #  k个增广加权融合
     mix += ws[i] * preprocess(image_aug) 
   #  与原图加权融合
   mixed = (1 - m) * preprocess(image) + m * mix 
   return mixed 

2.AugMix_DataSet:

class AugMixDataset(torch.utils.data.Dataset): 
   """Dataset wrapper to perform AugMix augmentation.""" 
  
   def __init__(self, dataset, preprocess, no_jsd=False): 
     self.dataset = dataset 
     self.preprocess = preprocess 
     self.no_jsd = no_jsd 

   def __getitem__(self, i): 
     x, y = self.dataset[i] 
     # 如果不使用js损失,直接返回augmix图像和标签
     if self.no_jsd: 
       return aug(x, self.preprocess), y 
     # 如果使用js损失,则返回原图,2次augmix图像组成的tuple和标签
     else: 
       im_tuple = (self.preprocess(x), aug(x, self.preprocess), 
                   aug(x, self.preprocess)) 
       return im_tuple, y 

   def __len__(self): 
     return len(self.dataset) 

3.train:

def train(net, train_loader, optimizer, scheduler): 
   """Train for one epoch.""" 
   net.train() 
   loss_ema = 0. 
   for i, (images, targets) in enumerate(train_loader): 
    optimizer.zero_grad() 
   # 不使用js损失
   if args.no_jsd: 
       images = images.cuda() 
       targets = targets.cuda() 
       logits = net(images) 
       loss = F.cross_entropy(logits, targets) 
    else: 
       images_all = torch.cat(images, 0).cuda() 
       targets = targets.cuda() 
       logits_all = net(images_all) 
       logits_clean, logits_aug1, logits_aug2 = torch.split( 
           logits_all, images[0].size(0)) 
       #  使用js损失则包含两个损失函数
       # Cross-entropy is only computed on clean images 
       loss = F.cross_entropy(logits_clean, targets) 

       p_clean, p_aug1, p_aug2 = F.softmax( 
           logits_clean, dim=1), F.softmax( 
               logits_aug1, dim=1), F.softmax( 
                   logits_aug2, dim=1) 

       # Clamp mixture distribution to avoid exploding KL divergence 
       p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log() 
       loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') + 
                     F.kl_div(p_mixture, p_aug1, reduction='batchmean') + 
                     F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3. 

     loss.backward() 
     optimizer.step() 
     scheduler.step() 
     loss_ema = loss_ema * 0.1 + float(loss) * 0.9 
     if i % args.print_freq == 0: 
       print('Train Loss {:.3f}'.format(loss_ema)) 

   return loss_ema 

  • 5
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值