看了pytorch官方提供的tutorial中transfer learning这个例子,对其中的数据读取部分很是模糊,于是仔细分析了一番,今天写一篇博客记录一下自己所看所得。
dataloader
下面这段代码最终得到了dataloader,dataloader是python中的可迭代对象,我们可以通过for循环讲数据一一取出。
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
data_dir = '../data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
然而问题的关键在于我们如何获得dataloader,答案是将自定义的dataset传入torch.utils.data.DataLoader
中。至于dataloader如何使用dataset工作,我会在下次在进行分析,这次的关键是dataset的制作。
dataset
其实dataset的制作我在上次的博客中也做了分析–从python中的一些特殊方法讲到pytorch的官方例子mnist(主要针对pytorch的自定义dataset中的几个特殊函数进行说明)。
需要在自定义的dataset类做到以下几点:
1. 继承torch.utils.data.Dataset
类。
2. 重写__init__
方法、__getitem__
方法、__len__
方法以及__repr__方法
(非必须),至于每个类的作用我在上篇博客已经有很详细的讲解。
下面我们看一下分类任务专用的dataset类:torchvision.datasets.ImageFolder
:
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):# 根据地址读取图像
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
class ImageFolder(DatasetFolder):
def __init__(self, root, transform=None, target_transform=None,
loader=pil_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples
这个类是一个叫做DatasetFolder类的子类,主要的功能都写在了那个类中,该类的主要作用就是传递了两个额外参数:loader
和IMG_EXTENSIONS
。loder是上面定义函数pil_loader()
的引用,该函数的作用是根据传入的图像地址进行图像读取;IMG_EXTENSIONS
定义了读取图像文件的扩展名类型。其余在调用父类__init__
方法时传入的参数在最外面就已经传入,包括root
表示路径、transform
表示要对图像进行的变换。(看第一段代码传入的参数)
接下来看DatasetFolder
类的定义:
class DatasetFolder(data.Dataset):
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
self.root = root
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
从上面代码可以看出DatasetFolder
类的定义遵从了自定义的dataset类时需要遵守的几点规则。从上篇博客我们已经知道__getitem__
方法是用来获取dataset中的数据的,但这个不是本次的重点,本次重点是为何__getitem__
方法中的代码能够实现获取数据。
首先来看下面一段测试代码,看过之后就会大致明白。这段代码将关键的函数和句子放上去进行测试。
import os
# has_file_allowed_extension函数的功能是根据文件名判断该文件是否具有所需图像类型扩展名的后缀
def has_file_allowed_extension(filename, extensions):
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
# find_classes函数的功能是根据输入的存放图像的文件夹地址,得到文件夹下面有几种图像,为每种图像分配一个数字
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
# make_dataset函数会根据图像种类字典、存放图像的文件夹地址以及扩展名列表得到每个图像的地址以及种类信息组成的列表
def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item)
return images
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
root = '../data/hymenoptera_data/train'
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
# 从输出结果可以看出:classes是由存放每类图像的文件夹名字组成的列表;
# class_to_idx是由每种图像的类名和为其分配的数字组成的键值对所组成的字典;
# samples是由个数与所有类图像总个数相等的元组组成的列表,元组里面的内容就对应了每张图像的地址以及它的分类编号。
# 有了这些信息,就能够通过__getitem__方法中的前两句代码:
# path, target = self.samples[index]
# sample = self.loader(path)
# 获取到图像和其对应分类了。
print(classes)
print(class_to_idx)
print(samples)
print(len(samples))
输出内容如下:
['ants', 'bees']
{'bees': 1, 'ants': 0}
[('../data/hymenoptera_data/train/ants/0013035.jpg', 0), ('../data/hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg', 0), ('../data/hymenoptera_data/train/ants/1095476100_3906d8afde.jpg', 0), ('../data/hymenoptera_data/train/ants/1099452230_d1949d3250.jpg', 0), ('../data/hymenoptera_data/train/ants/116570827_e9c126745d.jpg', 0), ('../data/hymenoptera_data/train/ants/1225872729_6f0856588f.jpg', 0), ('../data/hymenoptera_data/train/ants/1262877379_64fcada201.jpg', 0), ('../data/hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg', 0), ('../data/hymenoptera_data/train/ants/1286984635_5119e80de1.jpg', 0), ('../data/hymenoptera_data/train/ants/132478121_2a430adea2.jpg', 0), ('../data/hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg', 0), ('../data/hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg', 0), ('../data/hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg', 0), ('../data/hymenoptera_data/train/ants/148715752_302c84f5a4.jpg', 0), ('../data/hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg', 0), ('../data/hymenoptera_data/train/ants/149244013_c529578289.jpg', 0),