# Load a Dataset from a folder.
import torch.utils.data as data
from PIL import Image
import os
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path of the image file to read.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    # Image.open() is lazy, so open the file ourselves and let the ``with``
    # block close the handle deterministically. convert() returns a converted
    # copy of the image, so the result does not depend on the file staying
    # open after we return.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
# File-name suffixes that are treated as images.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg``
    are accepted in addition to the all-lower and all-upper spellings
    listed in ``IMG_EXTENSIONS``.

    Args:
        filename: File name (or path) to test.

    Returns:
        bool: Whether the name has a recognised image suffix.
    """
    # Built on each call so later changes to IMG_EXTENSIONS are respected;
    # str.endswith accepts a tuple and checks every suffix in one call.
    suffixes = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset(dir):
    """Return the paths of all image files found under ``dir``.

    The directory tree is walked recursively and every file accepted by
    ``is_image_file`` is collected.

    Args:
        dir: Root directory to search.

    Returns:
        list: Full paths (containing directory joined with the file name),
        not just file names. Directories are visited in sorted order and
        file names are sorted within each directory, so the order is
        deterministic for a given directory tree.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # still runs when Python is started with -O.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)

    images = []  # full paths of the image files, not just their names
    # os.walk() returns a generator of (root, dirs, files) triples; sorted()
    # turns it into a list ordered by root so directories are always visited
    # in the same order.
    # Reference: https://blog.csdn.net/weixin_44887621/article/details/118810180
    for root, _, files in sorted(os.walk(dir)):
        # The order of names inside one directory is arbitrary, so sort
        # them as well to keep the result reproducible.
        for fname in sorted(files):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images
class ImageFolder(data.Dataset):
    """Dataset over the image files found recursively under a directory.

    Args:
        root: Directory to search for images.
        transform: Optional callable applied to each loaded image.
        return_path: If true, ``__getitem__`` returns ``(img, path)``
            instead of just ``img``.
        loder: Callable that takes a path and returns the loaded image.
            The misspelled parameter name is kept so existing keyword
            callers keep working.

    Raises:
        RuntimeError: If no image file is found under ``root``.
    """

    def __init__(self, root='makeup', transform=None, return_path=False,
                 loder=default_loader):
        # Sorted so that an index always maps to the same file.
        imgs = sorted(make_dataset(root))
        if not imgs:
            raise RuntimeError("Found 0 image in:" + str(root) + "\n"
                               "Supported image extensions are:" +
                               ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_path = return_path
        # Use the loader passed by the caller; previously this argument was
        # ignored and default_loader was always used.
        self.loader = loder

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``transform`` if one is set.

        Returns:
            ``(img, path)`` when ``return_path`` is true, otherwise ``img``.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_path:
            return img, path
        return img