1.celeba数据集
这是一个大规模人脸数据集官网
香港中文大学发布的,提供了百度云盘的下载,使用很方便。
总共有202,599张图片,且有图片的标注(Label)文件。 最常用的是剪裁过的图片,文件名叫img_align_celeba
2.pytorch加载
pytorch加载数据集一般分为两步,第一步是创建一个代表整个数据集的对象dataSet
from torchvision import datasets
Celeba_dataset = datasets.ImageFolder(path, transform=torchvision.transforms.ToTensor())
# #dataset=torchvision.datasets.FashionMNIST('./data/', train=True, download=True, transform=transform),
# #dataset=torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform),
# dataset=torchvision.datasets.MNIST('./data/', train=True, download=True, transform=transform),
第二步是创建一个dataloader,用于迭代一批数据,让整个数据分批训练,这里batch_size就是一批图片的大小。
data_loader = torch.utils.data.DataLoader(dset, batch_size=128, shuffle=true,drop_last=True)
3.加载报错
如果报错:
RuntimeError: Found 0 files in subfolders of: Data/celeb_data/resized_celeb/ Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
这时要注意torchvision.datasets.ImageFolder这个数据集的子类,如果用这个默认类读取图片文件,需要在该文件下再创建文件夹作为类别标签,因为它的格式是
img_file:
label1:
1.jpg
2.jpg
3.jpg
label2:
1...
所以在img_align_celeba文件夹外再套一个文件夹,把ImageFolder的path参数名改为外面那个文件夹即可.
stack_overflow上也有类似的解释:
https://stackoverflow.com/questions/56720653/image-is-in-jpeg-but-torchvision-shows-image-extension-is-unsupported