Loading an image dataset in PyTorch takes two steps: first read the images with **torchvision.datasets.ImageFolder()**, then load the dataset with **torch.utils.data.DataLoader()**.
ImageFolder
torchvision.datasets.ImageFolder is a generic data loader for images organized as follows, with one subdirectory per class:
root/dog/xxx.png
root/dog/xxy.png
root/dog/[…]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[…]/asd932_.png
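To make the labeling rule concrete, here is a simplified, standard-library-only sketch of how such a layout turns into labels. It assumes torchvision's behavior of sorting the class folder names to assign indices and walking each class folder recursively. `index_image_folder` and the `EXTENSIONS` subset are hypothetical names for illustration, not torchvision API, and the `dog/sub/` folder merely stands in for a nested path.

```python
import os
import tempfile

# Assumed subset of extensions; torchvision's own list (IMG_EXTENSIONS) is longer.
EXTENSIONS = (".png", ".jpg", ".jpeg")


def index_image_folder(root):
    """Hypothetical helper mimicking ImageFolder's labeling convention.

    Each immediate subdirectory of `root` is a class; class indices follow the
    sorted order of the directory names; images may sit at any depth below.
    """
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        # Walk recursively so that nested files are labeled with the top-level class.
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


# Recreate the example layout in a temporary directory (empty files suffice here).
root = tempfile.mkdtemp()
for rel in ["dog/xxx.png", "dog/xxy.png", "dog/sub/xxz.png",
            "cat/123.png", "cat/nsdf3.png", "cat/readme.txt"]:
    path = os.path.join(root, rel)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "wb").close()

classes, class_to_idx, samples = index_image_folder(root)
print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
print(len(samples))  # 5 (readme.txt is skipped)
```

The label comes only from the top-level folder name, so renaming or adding a class folder can shift every index after it in sorted order.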
The ImageFolder class is defined as follows:
class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
Args:
- root (string): Root directory path.
- transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
- target_transform (callable, optional): A function/transform that takes in the target and transforms it.
- loader (callable, optional): A function to load an image given its path.
- is_valid_file (callable, optional): A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
Attributes:
- classes (list): List of the class names sorted alphabetically.
- class_to_idx (dict): Dict with items (class_name, class_index).
- imgs (list): List of (image path, class_index) tuples
The code below shows how to load data with ImageFolder and build an iterable data loader with DataLoader.
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ])
}

root = 'path/to/root'  # placeholder: directory laid out as root/<class>/<image>
batch_size = 32        # example value
num_workers = 4        # example value

# ImageFolder: the generic image loader
dataset = torchvision.datasets.ImageFolder(root, transform=data_transforms['train'])
# Build an iterable data loader
inputs = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for data, label in inputs:
    pass  # data: [batch_size, 3, 224, 224] float tensor; label: [batch_size] int64 tensor
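Sometimes you need not only the image data and labels but also the file path of each image. In that case, define a custom class that extends torchvision.datasets.ImageFolder, as in the following example.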
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    # Extends torchvision.datasets.ImageFolder so that each sample also includes its image path
    def __getitem__(self, index):
        # ImageFolder normally returns (image, label)
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # Image path
        path = self.imgs[index][0]
        # Build a new tuple: the original (image, label) plus the image path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

dataset = ImageFolderWithPaths(root, transform=data_transforms['train'])
inputs = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for datas, label, paths in inputs:
    pass  # paths holds the file path of each image in the batch
DataLoader
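torch.utils.data.DataLoader builds an iterable data loader. It combines a dataset and a sampler, and provides a single-process or multi-process iterator over the dataset.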
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
Parameters:
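- dataset (Dataset) – the dataset from which to load the data.
- batch_size (int, optional) – how many samples per batch to load (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
- sampler (Sampler, optional) – defines the strategy for drawing samples from the dataset. It is mutually exclusive with shuffle: passing a sampler together with shuffle=True raises an error.
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0).
- collate_fn (callable, optional) – merges a list of samples into a mini-batch; pass a custom function to control how samples are combined.
- pin_memory (bool, optional) – if True, the data loader copies tensors into CUDA pinned (page-locked) memory before returning them.
- drop_last (bool, optional) – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the dataset size is not divisible by the batch size, the last batch will be smaller (default: False).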
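The sketch below exercises a few of these parameters on a 10-sample `TensorDataset`: the effect of `drop_last`, the conflict between `sampler` and `shuffle=True`, and a custom `collate_fn`. `features_and_label_list` is a hypothetical collate function made up for this example.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# 10 toy samples: feature [i] and label i % 2.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10) % 2)

# batch_size / drop_last: 10 samples in batches of 4.
keep = [len(y) for _, y in DataLoader(dataset, batch_size=4, drop_last=False)]
drop = [len(y) for _, y in DataLoader(dataset, batch_size=4, drop_last=True)]
print(keep)  # [4, 4, 2]
print(drop)  # [4, 4]

# sampler and shuffle=True cannot be combined; construction fails.
conflict_raised = False
try:
    DataLoader(dataset, sampler=SequentialSampler(dataset), shuffle=True)
except ValueError as err:
    conflict_raised = True
    print("ValueError:", err)


def features_and_label_list(batch):
    """Hypothetical collate_fn: stack the features, return labels as a Python list."""
    features = torch.stack([x for x, _ in batch])
    labels = [int(y) for _, y in batch]
    return features, labels


x, y = next(iter(DataLoader(dataset, batch_size=3, collate_fn=features_and_label_list)))
print(x.shape, y)  # torch.Size([3, 1]) [0, 1, 0]
```

`collate_fn` receives the list of samples that make up one batch, so it is the natural place to handle samples the default collation cannot stack, such as variable-length sequences.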