ImageFolder reads the images in a folder together with their classes and builds a map-style dataset that a DataLoader can iterate over. ImageFolder expects this layout:
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
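To see how ImageFolder turns this layout into labels, here is a minimal stdlib-only sketch that mimics its indexing. `index_image_root` and `IMG_EXTENSIONS` are hypothetical names, not part of torchvision. The sketch assumes, as torchvision's implementation does, that the classes are the alphabetically sorted immediate subdirectories of the root and that images may sit at any depth below their class folder. The extension list is an approximation of the one torchvision uses.

```python
import os

# Assumed image extensions (approximately torchvision's default list)
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def index_image_root(root):
    """Hypothetical helper mimicking how ImageFolder indexes one root."""
    # Class names are the sorted names of the immediate subdirectories
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    # Indices are assigned per root and always start from 0
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        # Walk recursively so nested images (root/dog/[...]/xxz.png) are included
        for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples
```

Because the indices come from the sorted folder names of a single root, a root that contains only `dog/` gets `{'dog': 0}`, and a separate root that contains only `cat/` also gets `{'cat': 0}`.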
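However, for various reasons, such as the format the raw data was originally stored in, or a limit on the maximum number of files in a single directory on some operating systems (e.g., CentOS) (long live HDF5), the data may end up split across several roots, for example: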
sub_root1/dog/xxx.png
sub_root1/dog/xxy.png
sub_root1/dog/[...]/xxz.png
sub_root2/cat/123.png
sub_root2/cat/nsdf3.png
sub_root2/cat/[...]/asd932_.png
In that case each root has to be read with its own ImageFolder and the results merged. The code is as follows:
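import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms


def merge_datasets(dataset, sub_dataset):
    '''
    Merge sub_dataset into dataset in place. Attributes to merge:
    classes (list): List of the class names sorted alphabetically.
    class_to_idx (dict): Dict with items (class_name, class_index).
    samples (list): List of (sample path, class_index) tuples
    targets (list): The class_index value for each image in the dataset
    '''
    # Merge classes: union of both class lists, sorted alphabetically
    classes = sorted(set(dataset.classes) | set(sub_dataset.classes))
    # Merge class_to_idx: rebuild it from the merged classes. Each ImageFolder
    # numbers its own classes from 0, so dict.update() would give different
    # classes the same index.
    class_to_idx = {name: i for i, name in enumerate(classes)}

    def remap(ds):
        # Map each local index back to its class name, then to the merged index
        idx_to_class = {i: name for name, i in ds.class_to_idx.items()}
        return [(path, class_to_idx[idx_to_class[t]]) for path, t in ds.samples]

    # Merge samples (remap before dataset.class_to_idx is overwritten)
    samples = remap(dataset) + remap(sub_dataset)
    dataset.classes = classes
    dataset.class_to_idx = class_to_idx
    dataset.samples = samples
    dataset.imgs = samples  # ImageFolder keeps imgs as an alias of samples
    # Merge targets
    dataset.targets = [t for _, t in samples]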
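The essential step when merging is re-labelling: every root numbers its classes from 0, so a local index only means something together with that root's own `class_to_idx`. The framework-free sketch below shows the same idea for any number of roots at once. `merge_label_spaces` is a hypothetical helper, and it assumes each root is described by a `(class_to_idx, samples)` pair shaped like the ImageFolder attributes of the same names.

```python
def merge_label_spaces(parts):
    """Hypothetical helper: merge several (class_to_idx, samples) pairs into one label space."""
    # Union of class names across all roots, sorted alphabetically like ImageFolder
    classes = sorted({name for local_map, _ in parts for name in local_map})
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for local_map, local_samples in parts:
        # Invert the local mapping so each local index can be turned back into a name
        idx_to_name = {i: name for name, i in local_map.items()}
        samples.extend((path, class_to_idx[idx_to_name[t]]) for path, t in local_samples)
    targets = [t for _, t in samples]
    return classes, class_to_idx, samples, targets


parts = [
    ({"dog": 0}, [("sub_root1/dog/xxx.png", 0), ("sub_root1/dog/xxy.png", 0)]),
    ({"cat": 0}, [("sub_root2/cat/123.png", 0)]),
]
classes, class_to_idx, samples, targets = merge_label_spaces(parts)
print(classes, targets)  # ['cat', 'dog'] [1, 1, 0]
```

Both inputs use index 0, yet after the merge `dog` becomes 1 and `cat` becomes 0, because the labels are translated through the class names rather than copied.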
To verify that it works:
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Raw strings keep the Windows backslashes from being read as escape sequences
paths = [r"E:\datasets\office31\amazon\asub1",
         r"E:\datasets\office31\amazon\asub2",
         r"E:\datasets\office31\amazon\asub3"]
# The transform only needs to be set on the first dataset: all merged samples
# are loaded through its __getitem__
dataset = ImageFolder(root=paths[0], transform=transform)
for path in paths[1:]:
    sub_dataset = ImageFolder(path)
    merge_datasets(dataset, sub_dataset)
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=False)
for images, targets in dataloader:
    print(targets)
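Printing the batches shows that loading works, but a per-class count is a quicker way to confirm the labels survived the merge. The sketch below assumes only the `targets` and `classes` lists that ImageFolder exposes; `class_histogram` is a hypothetical helper name.

```python
from collections import Counter


def class_histogram(targets, classes):
    """Hypothetical sanity check: number of samples per class name."""
    counts = Counter(targets)
    # Classes that received no samples show up with a count of 0
    return {name: counts.get(i, 0) for i, name in enumerate(classes)}


print(class_histogram([1, 1, 0], ["cat", "dog"]))  # {'cat': 1, 'dog': 2}
```

On the merged dataset this would be called as `class_histogram(dataset.targets, dataset.classes)`. If a class that has images on disk reports 0 while another class absorbs its count, labels from different roots have collided.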