从pytorch的transfer learning tutorial讲分类任务的数据读取(深入分析torchvision.datasets.ImageFolder源码)

本文深入分析PyTorch官方transfer learning教程中的数据读取部分,特别是torchvision.datasets.ImageFolder的源码。通过讨论dataloader和dataset的运作,解释如何自定义dataset,并展示如何读取本地图像进行分类任务。
摘要由CSDN通过智能技术生成

看了pytorch官方提供的tutorial中transfer learning这个例子,对其中的数据读取部分很是模糊,于是仔细分析了一番,今天写一篇博客记录一下自己所看所得。

dataloader

下面这段代码最终得到了dataloader,dataloader是python中的可迭代对象,我们可以通过for循环讲数据一一取出。

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = '../data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}

  然而问题的关键在于我们如何获得dataloader,答案是将自定义的dataset传入torch.utils.data.DataLoader中。至于dataloader如何使用dataset工作,我会在下次在进行分析,这次的关键是dataset的制作。

dataset

  其实dataset的制作我在上次的博客中也做了分析–从python中的一些特殊方法讲到pytorch的官方例子mnist(主要针对pytorch的自定义dataset中的几个特殊函数进行说明)
  需要在自定义的dataset类做到以下几点:
1. 继承torch.utils.data.Dataset类。
2. 重写__init__方法、__getitem__方法、__len__方法以及__repr__方法(非必须),至于每个类的作用我在上篇博客已经有很详细的讲解。
  下面我们看一下分类任务专用的dataset类:torchvision.datasets.ImageFolder

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

def pil_loader(path):# 根据地址读取图像
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class ImageFolder(DatasetFolder):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=pil_loader):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                          transform=transform,
                                          target_transform=target_transform)
        self.imgs = self.samples

  这个类是一个叫做DatasetFolder类的子类,主要的功能都写在了那个类中,该类的主要作用就是传递了两个额外参数:loaderIMG_EXTENSIONS。loder是上面定义函数pil_loader()的引用,该函数的作用是根据传入的图像地址进行图像读取;IMG_EXTENSIONS定义了读取图像文件的扩展名类型。其余在调用父类__init__方法时传入的参数在最外面就已经传入,包括root表示路径、transform表示要对图像进行的变换。(看第一段代码传入的参数)
  接下来看DatasetFolder类的定义:

class DatasetFolder(data.Dataset):

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx, extensions)
        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

  从上面代码可以看出DatasetFolder类的定义遵从了自定义的dataset类时需要遵守的几点规则。从上篇博客我们已经知道__getitem__方法是用来获取dataset中的数据的,但这个不是本次的重点,本次重点是为何__getitem__方法中的代码能够实现获取数据。
  首先来看下面一段测试代码,看过之后就会大致明白。这段代码将关键的函数和句子放上去进行测试。

import os

# has_file_allowed_extension函数的功能是根据文件名判断该文件是否具有所需图像类型扩展名的后缀
def has_file_allowed_extension(filename, extensions):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)

# find_classes函数的功能是根据输入的存放图像的文件夹地址,得到文件夹下面有几种图像,为每种图像分配一个数字
def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

# make_dataset函数会根据图像种类字典、存放图像的文件夹地址以及扩展名列表得到每个图像的地址以及种类信息组成的列表
def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

root = '../data/hymenoptera_data/train'
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
# 从输出结果可以看出:classes是由存放每类图像的文件夹名字组成的列表;
# class_to_idx是由每种图像的类名和为其分配的数字组成的键值对所组成的字典;
# samples是由个数与所有类图像总个数相等的元组组成的列表,元组里面的内容就对应了每张图像的地址以及它的分类编号。
# 有了这些信息,就能够通过__getitem__方法中的前两句代码:
#                                               path, target = self.samples[index]
#                                               sample = self.loader(path)
#                                               获取到图像和其对应分类了。
print(classes)
print(class_to_idx)
print(samples)
print(len(samples))

  输出内容如下:

['ants', 'bees']
{'bees': 1, 'ants': 0}
[('../data/hymenoptera_data/train/ants/0013035.jpg', 0), ('../data/hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg', 0), ('../data/hymenoptera_data/train/ants/1095476100_3906d8afde.jpg', 0), ('../data/hymenoptera_data/train/ants/1099452230_d1949d3250.jpg', 0), ('../data/hymenoptera_data/train/ants/116570827_e9c126745d.jpg', 0), ('../data/hymenoptera_data/train/ants/1225872729_6f0856588f.jpg', 0), ('../data/hymenoptera_data/train/ants/1262877379_64fcada201.jpg', 0), ('../data/hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg', 0), ('../data/hymenoptera_data/train/ants/1286984635_5119e80de1.jpg', 0), ('../data/hymenoptera_data/train/ants/132478121_2a430adea2.jpg', 0), ('../data/hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg', 0), ('../data/hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg', 0), ('../data/hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg', 0), ('../data/hymenoptera_data/train/ants/148715752_302c84f5a4.jpg', 0), ('../data/hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg', 0), ('../data/hymenoptera_data/train/ants/149244013_c529578289.jpg', 0), 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值