1. 首先导入所需要的包,其中torchvision包主要实现数据的处理、导入和预览
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
2.torchvision中的datasets可以实现对数据集的下载,例如MNIST、COCO、ImageNet、CIFCAR,代码如下:
# Download the datasets
data_train = datasets.MNIST(root = "./data/",
transform=transform,
train = True,
download = True)
data_test = datasets.MNIST(root = "./data/",
transform=transform,
train = False)
其中:root指定了数据集下载后的存放路径,train指定了当前为测试集还是训练集,transfrom指定了所应用的变换,其代码如下:
# Set the transform format
transform = transforms.Compose([transforms.ToTensor(),
])
3.transforms的具体应用:
transforms主要负责对载入的数据进行变换,主要是变为Tensor类型,以及归一化和大小缩放的操作。除此以外,当数据集比较有限时,可以通过变换训练集生成更多的数据进行训练。(数据增强)
上一段的Compose可以看作一种容器,所传入的是一个列表,能够容纳多种数据变换。常用的数据变换操作有:
torchvision.transforms.Resize(h,w):对载入的图片数据按照需求大小进行缩放;
torchvision.transforms.Scale(h,w) : 同上
torchvision.transforms.CenterCrop(h,w) : 以图片中心为参考点,对载入的图片进行裁剪
torchvision.transforms.RondomHorizontalFlip(rate) : 对载入图片随机水平翻转
torchvision.transforms.RondomVerticalFlip(rate) : 对载入图片随机垂直翻转
torchvision.transforms.ToTensor()
torchvision.transforms.ToPILImage()