如何在 PyTorch 上加载自定义数据集

引言

当你想要构建一个机器学习模型时,需要做的第一件事就是准备数据集。当数据是表格格式时,很容易准备它们。但是像图片这样的数据呢?

图像与表格数据的格式不同,这种格式有很多数据表示方式。

有些人根据相应的类将图像放到一个文件夹中,有些人将元数据放在表格格式中,用于描述图像文件名及其标签。

当数据集处于第一种格式时,我们可以使用 torch.data.utils 库中的一个名为 imageolder 的类更容易地加载数据集。

但是大多数时候,图像数据集具有第二种格式,它由元数据和图像文件夹组成。因此,我们必须付出一些努力来准备数据集。

例如,您希望使用深度学习构建一个图像分类器,它由一个元数据组成,如下所示:

如你所见,数据集由图像 id 和标签组成。在这种情况下,图像 id 还表示.jpg格式的文件名,而标签则采用一键编码格式。

我们如何加载数据集,使模型可以读取图像和它们的标签?您可以做的是构建一个可以包含它们的对象。

在本文中,我将向你展示如何使用 PyTorch 加载包含元数据的图像数据集。

预处理元数据

我们要做的第一件事就是对元数据进行预处理。从上图中我们可以看到,数据集不包含图像文件名,还包含了标签。

但值得庆幸的是,图像 id 还可以通过在 id 中添加 .jpg 来表示图像文件名。生成图像文件名的代码如下所示:

import pandas as pd
# ASSUME THAT YOU RUN THE CODE ON KAGGLE NOTEBOOK
path = '/kaggle/input/plant-pathology-2020-fgvc7/'
img_path = path + 'images'
# LOAD THE DATASET
train_df = pd.read_csv(path + 'train.csv')
test_df = pd.read_csv(path + 'test.csv')
sample = pd.read_csv(path + 'sample_submission.csv')
# GET THE IMAGE FILE NAME
train_df['img_path'] = train_df['image_id'] + '.jpg'
test_df['img_path'] = test_df['image_id'] + '.jpg'
train_df.head()

结果如下:

在获得图像文件名之后,现在我们可以将标签取消中心位置,变成单一列。代码如下:

# UNPIVOT / MELT THE LABELS
train_label = train_df.melt(id_vars=['image_id', 'img_path'])
# FILTER THE DATA
train_label = train_label[train_label['value'] == 1]
# GET THE IMAGE ID NUMBER
train_label['id'] = [int(i[1]) for i in train_label['image_id'].str.split('_')]
# RESET THE INDEX
train_label = train_label.sort_values('id').reset_index()
# ADD THE LABEL TO THE DATASET
train_df['label'] = train_label['variable']
# REFORMAT THE DATASET
train_df = train_df[train_df.columns[[0, 5, 1, 2, 3, 4, 6]]]
print(train_label.shape)
train_df.head()

结果如下:

因为机器学习模型只能读取数字,所以我们必须将标签编码为数字。代码如下:

from sklearn.preprocessing import LabelEncoder
# Encode the label
le = LabelEncoder()
label_encoded = le.fit_transform(train_df['label'])
train_df['label_encoded'] = label_encoded
# Taking the class name
label_names = label_encoded.classes_
train_df.head()

结果如下:

在我们预处理完元数据之后,就可以进入下一步了。

使用数据集类构建图像容器

下一步是为我们的图像和标签构建一个容器对象。我们需要构建这个对象的原因是为了使我们将数据加载到深度学习模型的任务更加容易。因此,我们可以通过索引访问图像及其标签。

要创建这个对象,我们可以使用一个名为 Dataset 的类,该类来自 torch.utils.data library。此类是一个抽象类,因为它包含尚未实现的函数或方法。因此,我们可以根据自己的需要来实现这些功能。

我们需要实现的功能是:

  • __init__ function 

  • __len__ function

  • __getitem__ function

__init__ 函数将从其类中初始化一个对象,并从用户那里收集参数。__len__ 函数将返回数据集的长度。最后, __getitem__ 函数将帮助我们利用索引返回数据观测值。

通过理解类及其相应的函数,现在我们可以实现代码。我将使用名为 PathologyPlantsDataset 的类名,它继承 Dataset 类的函数。代码如下:

class PathologyPlantsDataset(Dataset):
  """
  The Class will act as the container for our dataset. It will take your dataframe, the root path, and also the transform function for transforming the dataset.
  """
    def __init__(self, data_frame, root_dir, transform=None):
        self.data_frame = data_frame
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        # Return the length of the dataset
        return len(self.data_frame)
    
    def __getitem__(self, idx):
        # Return the observation based on an index. Ex. dataset[0] will return the first element from the dataset, in this case the image and the label.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 1])
        image = Image.open(img_name)
        label = self.data_frame.iloc[idx, -1]
        
        if self.transform:
            image = self.transform(image)
    
        return (image, label)

创建类之后,现在我们可以从中构建对象。创建对象时,我们将设置由数据集,根目录和转换函数组成的参数。代码如下:

# INSTANTIATE THE OBJECT
pathology_train = PathologyPlantsDataset(
    data_frame=train_part,
    root_dir=path + 'images',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

现在,我们可以使用对象提取图像及其标签。如前所述,为了从数据访问观察值,我们可以使用索引。

例如,当我们要访问数据集的第三行(索引为2)时,可以使用pathology_train [2]进行访问。

下面给大家展示一个例子,说明如何使用 patrology_train 变量来可视化结果:

temp_img, temp_lab = pathology_train[2]
plt.imshow(temp_img.numpy().transpose((1, 2, 0)))
plt.title(label_names[temp_lab])
plt.axis('off')
plt.show()

结果如下:

总结

很好!现在我们已经实现了这个对象,它可以更容易地为我们的深度学习模型加载数据了!希望同学们可以用自己的数据集尝试一下~

·  END  ·

HAPPY LIFE

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值