PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。这3个子包的具体介绍可以参考官网:https://pytorch.org/docs/master/torchvision/index.html
我的另外两篇博客对其他两个部分做了介绍分别为:
torchvision.datasetshttps://blog.csdn.net/sinat_42239797/article/details/93916790
torchvision.modelshttps://blog.csdn.net/sinat_42239797/article/details/94329987
这篇博客介绍torchvision.transformas。
torchvision.transforms这个包中包含resize、crop等常见的data augmentation操作,基本上PyTorch中的data augmentation操作都可以通过该接口实现。该包主要包含两个脚本:transformas.py和functional.py,前者定义了各种data augmentation的类,在每个类中通过调用functional.py中对应的函数完成data augmentation操作。
pytorch的图像变换模块主要由五部分构成,分别为:
Transforms on PIL Image、Transforms on torch.Tensor、Conversion Transforms、Generic Transforms、Functional Transforms
常见的变换Transforms on PIL Image、Transforms on torch.Tensor、Conversion Transforms它们可以使用链接在一起Compose。
先举一个常用操作的简单例子感受一下:
import torchvision
import torch
train_augmentation = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.RandomCrop(224),
torchvision.transofrms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225])
])
Class custom_dataread(torch.utils.data.Dataset):
def __init__():
...
def __getitem__():
# use self.transform for input image
def __len__():
...
train_loader = torch.utils.data.DataLoader(
custom_dataread(transform=train_augmentation),
batch_size = batch_size, shuffle = True,
num_workers = workers, pin_memory = True)
注意
torchvision.transforms.ToTensor(),
torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225]
这两行使用使用时顺序不可以颠倒,原因是因为归一化需要是Tensor型的数据,所以要先将数据转化为Tensor型才可以进行归一化。
一般情况下我们将对图片的变换操作放到torchvision.transforms.Compose()进行组组合变换。
Transforms on PIL Image
torchvision.transforms.CenterCrop(大小)
参数介绍
size(sequence 或int) - 作物的所需输出大小。如果size是int而不是像(h,w)这样的序列,则进行正方形裁剪(大小,大小)。
torchvision.transforms.ColorJitter(亮度= 0,对比度= 0,饱和度= 0,色调= 0 )#随机更改图像的亮度,对比度和饱和度。
参数介绍
亮度(浮点数或python的元组:浮点数(最小值,最大值)) - 抖动亮度多少。从[max(0,1-brightness),1 + brightness]或给定[min,max]均匀地选择brightness_factor。应该是非负数。
对比度(浮点数或python的元组:浮点数(最小值,最大值)) - 抖动对比度多少。contrast_factor从[max(0,1-contrast),1 + contrast]或给定[min,max]中均匀选择。应该是非负数。
饱和度(浮点数或python的元组数:float (min ,max )) - 饱和度抖动多少。饱和度_因子从[max(0,1-saturation),1 + saturation]或给定[min,max]中均匀选择。应该是非负数。
色调(浮点数或python的元组:浮点数(最小值,最大值)) - 抖动色调多少。从[-hue,hue]或给定的[min,max]中均匀地选择hue_factor。应该有0 <= hue <= 0.5或-0.5 <= min <= max <= 0.5。
torchvision.transforms.FiveCrop(大小)#将给定的PIL图像裁剪为四个角和中央裁剪
参数介绍
size(sequence 或int) - 作物的所需输出大小。如果大小是int 而不是像(h,w)这样的序列,则进行大小(大小,大小)的正方形裁剪。
例子:
transform = Compose([
FiveCrop(size), # this is a list of PIL Images
Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
])
#In your test loop you can do the following:
input, target = batch # input is a 5d tensor, target is 2d
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
torchvision.transforms.RandomCrop(size,padding = None,pad_if_needed &#