PyTorch(一)之 torchvision 加载数据

原创博客,转载请注明出处!

PyTorch是一个最近出的功能比较强大的框架。

torchvision和torch包是PyTorch框架比较重要的两个包,其中torchvision包括下面四部分

1. torchvision.datasets : 图片、视频等数据集的加载器
2. torchvision.models : 常见网络模型的定义,如Alexnet、VGG、Resnet以及它们的与训练模型
3. torchvision.transforms : 常见的图像转换工具,如随机裁剪、旋转等
4. torchvision.utils : 工具类,如保存张量(3 x h x w)作为图像到磁盘,给一个小批量创建一个图像网格等

 

准备工作:

数据可以有两种方式存放:

第一种:图片文件夹+txt文档

                通过txt文档映射它们的关系。

第二种:训练集和测试集分开,且每一类文件都放在同一子目录下

                即目录下一般包括两个文件夹:train和val,每个文件件下面包含N个子文件夹,N是你的分类类别数,且每个子文件夹里    存放的就是这个类别的图像。比如N = 3 ,那么文件夹目录可以是这样

data/

         train/

                man/xxxx.jpg

                man/yyyy.jpg

                dog/111.jpg

                cat/222.jpg

                ...

         val/

              man/111.jpg

              dog/222.jpg

              cat/111.jpg

              ...

 

那么如何利用它们来加载数据呢?

1、先看torchvision.transforms

torchvision.transforms包含了常见的图像变化(预处理)操作.这些变化可以用torchvision.transforms.Compose链接在一起. 

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

其中Compose方法是用来管理所有transforms操作的,其它常见方法如下:

ToTensor()是把图片数据转换成张量并转化范围在[0,1],

Normalize(mean,std)是归一化的方法,mean = (R, G, B),std = (R, G, B),如[0.485, 0.456, 0.406]和[0.229, 0.224, 0.225],

Resize(size)是将输入的PIL图像调整为给定的大小。参数可以int,也可以是int的元组(h,w),

CenterCrop(size)是将给定的 PIL.Image 进行中心切割,得到给定的 size,size 可以是 tuple,(target_height, target_width)。size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。

RandomCrop(size, padding=0)也是切割,不过切割中心点的位置随机选取。size 可以是 tuple 也可以是 Integer。

RandomHorizontalFlip(size, interpolation=2)是先将给定的 PIL.Image 随机切,然后再 resize 成给定的 size 大小。

RandomHorizontalFlip()是随机水平翻转给定的 PIL.Image,概率为 0.5。

RandomVerticalFlip()是随机垂直翻转给定的 PIL.Image,概率为 0.5。

ToPILImage()是将 shape 为 (C, H, W) 的 Tensor 或者 shape 为 (H, W, C) 的 numpy.ndarray 转换成 PIL.Image,值不变。

FiveCrop(size)是将给定的PIL图像剪裁成四个角落区域和中心区域。

Pad(padding, fill=0, padding_mode=‘constant’)是对给定的PIL图像的边缘进行填充,填充的数值为给定填充数值。

RandomAffine(degrees, translate=None, scale=None)是保持中心不变的对图片进行随机仿射变化。

RandomApply(transforms, p=0.5)是随机选取变换中(各种变换存储在列表中)的其中一个,同时给定一定的概率。

 

2、再看torchvision.datasets

torchvision.datasets是继承torch.utils.data.Dataset的子类. 因此,可以使用torch.utils.data.DataLoader对它们进行多线程处理

datasets下面有个datasets.ImageFolder方法可以实现数据导入,

ImageFolder(root,transform=None,target_transform=None,loader=default_loader)

root : 在指定的root路径下面寻找图片 
transform: 接收PIL图像的函数/转换并返回已转换的版本。 可以直接使用上面的Compose方法组合需要的变换
target_transform :对label进行变换 
loader: 指定加载图片的函数,默认操作是读取PIL image对象

 

3、接下来看torch.utils.data

有两个很重要的data.Datasetdata.DataLoaderdata.Dataset是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。比如前面说的torchvision.datasets.ImageFolder类就是继承它的,data.DataLoader也是继承它的。

前面torchvision.datasets.ImageFolder只是返回list,list是不能作为模型输入的,因此在PyTorch中需要用到data.DataLoader来将list类型的输入数据封装成Tensor数据格式,以备模型使用。对图像和标签分别封装成一个Tensor。

 

看完整代码

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224), #Random cutting of an image (224, 224) from the original image
        transforms.RandomHorizontalFlip(), #Reversal at 0.5 probability level
        transforms.ToTensor(),  #Convert a PIL. Image with a range of [0,255] or numpy. ndarray to a shape of [C, H, W], and a FloadTensor with a range of [0,1.0].
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #Normalization
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

#image data file
data_root = ''
image_datasets = {x: datasets.ImageFolder(os.path.join(data_root, x),
                                          data_transforms[x]) for x in ['train', 'val']}
# wrap your data and label into Tensor
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=10,
                                             shuffle=True,
                                             num_workers=4) for x in ['train', 'val']}

最后的dataloaders就是一个Variable数据类型,可以作为模型的输入了。

 

  • 8
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值