【CutMix实现/MixUp实现】TorchVision 0.16发布,一行代码实现CutMix/Mixup
北京时间2023年10月5日,TorchVision 0.16正式版本上线,带来了令人期待的多项更新
新的torchvision.transforms.v2
支持图像分类、分割、检测和视频任务,相较于旧版本速度提升10%-40%,对于v2.resize()
速度提升高达2-4倍。
本次更新同时带来了CutMix和MixUp的图片增强,用户可以在torchvision.transforms.v2
中直接调用它们,也可以通过dataloader
直接载入。
如何使用新的CutMix和MixUp
-
首先需要引入包
import torch from torchvision.datasets import FakeData from torchvision.transforms import v2
-
torchvision.transforms.v2
的调用preproc = v2.Compose([ v2.PILToTensor(), v2.RandomResizedCrop(size=(224, 224), antialias=True), v2.RandomHorizontalFlip(p=0.5), v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet ]) dataset = SampleData(size=1000, num_classes=100, transform=preproc)
-
在DataLoader后引入
MixUp
和CutMix
dataloader = DataLoader(dataset, batch_size=4, shuffle=True) # use MixUp and CutMix cutmix = v2.CutMix(num_classes=NUM_CLASSES) mixup = v2.MixUp(num_classes=NUM_CLASSES) cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) for images, labels in dataloader: # images shape: torch.Size([4, 3, 224, 224]) # labels shape: torch.Size([4]) images, labels = cutmix_or_mixup(images, labels) # after cutmix_or_mixup() # images shape: torch.Size([4, 3, 224, 224]) # labels shape: torch.Size([4, 100]) # <rest of the training loop here>
由此可见,cutmix和mixup的加入,使得网络Ground Truth的输入从 [batch_size] 变成了 [batch_size, num_classes] ,原本是index类型的标签转换为了one-hot类型的标签。
下面是一个经典的cat vs dog的例子:
更高效地使用CutMix和MixUp
上文提到的在 DataLoader 之后调用是使用 CutMix 和 MixUp 的最简单方法,但一个缺点是它没有利用 DataLoader 多线程处理。为此,我们可以将这些转换作为collation function
的一部分传递。
具体的例子如下:
from torch.utils.data import default_collate
def collate_fn(batch):
return cutmix_or_mixup(*default_collate(batch))
# 在此处调用collate_fn
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
for images, labels in dataloader:
# Your code here
# 无需在此处调用cutmix_or_mixup,因为它已经成为了DataLoader的一部分
# <rest of the training loop here>
break
参考链接
Github.com | TorchVision v0.16.0
官方文档
pytorch.org | HOW TO USE CUTMIX AND MIXUP