ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
图片结构如下所示:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
ImageFolder主要有4个参数:
- root:在root的路径下寻找图片
- transform:
- target_transform:对label的转换
- loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
先不加transform
没加transform,所有显示的还是PIL image格式
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt
dataset = ImageFolder('data/train/' ) #把cat和dog两个文件夹下的图片作为一个list
print(len(dataset)) #数据集长度
print(dataset.class_to_idx) ##cat文件夹对应label:0,dog文件夹对应label:1
print(dataset.imgs) ##打印所有文件的路径和对应的label
print(dataset[0][1])# 第一维是第几张图。第二维为1时返回label
print(dataset[0][0]) #第一维是第几张图。第二维为0时返回图片数据
plt.imshow(dataset[666][0])
plt.show()
显示结果:
加入transform
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt
transform = transforms.Compose([transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
dataset = ImageFolder('data/train/',transform = transform )
to_img = transforms.ToPILImage() #为了可视化转为tensor再转回图片
a = to_img(dataset[666][0]*0.5+0.5) # 0.5是标准差和均值的近似
plt.imshow(a)
plt.show()
这里就不贴图了。
关于transform可见下面我的另一篇博文:torchvision包