【CutMix实现/MixUp实现】TorchVision 0.16发布,一行代码实现CutMix/Mixup

【CutMix实现/MixUp实现】TorchVision 0.16发布,一行代码实现CutMix/Mixup

请添加图片描述

北京时间2023年10月5日,TorchVision 0.16正式版本上线,带来了令人期待的多项更新

新的torchvision.transforms.v2 支持图像分类、分割、检测和视频任务,相较于旧版本速度提升10%-40%,对于v2.resize()速度提升高达2-4倍。
本次更新同时带来了CutMix和MixUp的图片增强,用户可以在torchvision.transforms.v2中直接调用它们,也可以通过dataloader直接载入。

如何使用新的CutMix和MixUp

  1. 首先需要引入包

    import torch
    from torchvision.datasets import FakeData
    from torchvision.transforms import v2
    
  2. torchvision.transforms.v2 的调用

    preproc = v2.Compose([
        v2.PILToTensor(),
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
    ])
    
    dataset = SampleData(size=1000, num_classes=100, transform=preproc)
    
  3. 在DataLoader后引入MixUpCutMix

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    
    # use MixUp and CutMix
    cutmix = v2.CutMix(num_classes=NUM_CLASSES)
    mixup = v2.MixUp(num_classes=NUM_CLASSES)
    cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
    
    for images, labels in dataloader:
        # images shape: torch.Size([4, 3, 224, 224])
        # labels shape: torch.Size([4])
        images, labels = cutmix_or_mixup(images, labels)
        # after cutmix_or_mixup()
       	# images shape: torch.Size([4, 3, 224, 224])
        # labels shape: torch.Size([4, 100])
        # <rest of the training loop here>
    

    由此可见,cutmix和mixup的加入,使得网络Ground Truth的输入从 [batch_size] 变成了 [batch_size, num_classes] ,原本是index类型的标签转换为了one-hot类型的标签。

    下面是一个经典的cat vs dog的例子:
    请添加图片描述

更高效地使用CutMix和MixUp

上文提到的在 DataLoader 之后调用是使用 CutMix 和 MixUp 的最简单方法,但一个缺点是它没有利用 DataLoader 多线程处理。为此,我们可以将这些转换作为collation function的一部分传递。

具体的例子如下:

from torch.utils.data import default_collate


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))

# 在此处调用collate_fn
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)

for images, labels in dataloader:
    # Your code here
    # 无需在此处调用cutmix_or_mixup,因为它已经成为了DataLoader的一部分
    # <rest of the training loop here>
    break

参考链接
Github.com | TorchVision v0.16.0
官方文档
pytorch.org | HOW TO USE CUTMIX AND MIXUP

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值