加载指定目录下的图片
工于利其事,必先善其器。加载训练数据集是深度学习基础中的基础,因此在这里将加载数据集这一过程封装好,后续训练GAN时就可以快速进行模型测试了。
为了能够与大部分的pytorch代码兼容,最好是采用pytorch官方推荐的方式,通过继承Dataset实现加载数据的过程。对于GAN的训练,由于大部分情况下不需要与图片对应的标签,所以加载数据的情况可以分为:
- 图片都放在某一个目录下
- 图片放在某一个目录及其子目录下
因此可以先找到目录下所有需要处理的图片路径,然后读取图片并转化为Tensor的形式。代码如下,以GAN常用数据集FFHQ为例:
import pathlib
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
class ImagesFolder(Dataset):
def __init__(self, root, transform=None,distributed=False,open_mode=None):
self.images_path = self.getImagesPath(root,distributed)
self.transform