小白的进阶之路系列之四----人工智能从初步到精通pytorch自定义数据集上

前言

中,我们研究了如何在PyTorch (FashionMNIST)的内置数据集上构建计算机视觉模型。

在机器学习的许多不同问题上,我们采取的步骤是相似的。

找到一个数据集,将数据集转换为数字,建立一个模型(或找到一个现有的模型),以在这些数字中找到可用于预测的模式。

PyTorch有许多内置数据集,用于大量的机器学习基准测试,但是,您通常希望使用自己的自定义数据集。

什么是自定义数据集?

自定义数据集是与您正在处理的特定问题相关的数据集合。

本质上,自定义数据集几乎可以由任何东西组成。

例如,如果我们正在构建一个像nutrition这样的食物图像分类应用程序,我们的自定义数据集可能是食物图像。

或者,如果我们试图建立一个模型来分类网站上基于文本的评论是正面还是负面,我们的自定义数据集可能是现有客户评论及其评级的示例。

或者,如果我们试图构建一个声音分类应用程序,我们的自定义数据集可能是带有样本标签的声音样本。

或者,如果我们试图为在我们网站上购买商品的客户建立一个推荐系统,我们的自定义数据集可能是其他人购买的产品的示例。

但有时这些现有的功能可能还不够。

在这种情况下,我们总是可以创建torch.utils.data.Dataset的子类,并根据自己的喜好定制它。

本篇涵盖的内容

我们将用之前的文章中应用PyTorch工作流来解决计算机视觉问题。

但是我们不使用内置的PyTorch数据集,而是使用我们自己的披萨、牛排和寿司图像数据集。

我们的目标是加载这些图像,然后建立一个模型来训练和预测它们。

更加详细的,我们将讨论下面一些内容:

主题 内容
0 导入PyTorch并设置与设备无关的代码 让我们加载PyTorch,然后按照最佳实践将代码设置为与设备无关。
1 获得数据 我们将使用我们自己定制的披萨、牛排和寿司图像数据集。
2 与数据融为一体(数据准备) 在任何新的机器学习问题的开始,理解你正在处理的数据是至关重要的。在这里,我们将采取一些步骤来弄清楚我们拥有哪些数据。
3 转换数据 通常,你得到的数据不会100%准备好与机器学习模型一起使用,在这里我们将看看我们可以采取的一些步骤来转换我们的图像,以便它们准备好与模型一起使用。
4 使用ImageFolder加载数据(选项1) PyTorch为常见类型的数据提供了许多内置的数据加载函数。如果我们的图像是标准的图像分类格式,ImageFolder是有用的。
5 使用自定义数据集加载图像数据 如果PyTorch没有内置函数来加载数据呢?这是我们可以构建torch.utils.data.Dataset的自定义子类的地方。
6 其他形式的转换(数据增强) 数据增强是扩展训练数据多样性的常用技术。在这里,我们将探索火炬视觉的一些内置数据增强功能。
7 Model 0:没有数据增强的TinyVGG 到这个阶段,我们已经准备好了数据,让我们建立一个能够拟合数据的模型。我们还将创建一些训练和测试函数来训练和评估我们的模型。
8 探索损失曲线 损失曲线是观察你的模型如何训练/改进的好方法。它们也是一种很好的方法来判断你的模型是过拟合还是欠拟合。
9 Model 1:带数据增强功能的TinyVGG 到目前为止,我们已经尝试了一个没有数据增强的模型?
10 比较模型结果 让我们比较不同模型的损失曲线,看看哪个表现更好,并讨论一些改进性能的选项。
11 对自定义图像进行预测 我们的模型是在披萨、牛排和寿司图像的数据集上训练的。在本节中,我们将介绍如何使用我们训练好的模型来预测现有数据集之外的图像。

因为篇幅的原因,本篇将涵盖0至6章节的内容,剩下的部分,将在下篇中介绍。

好,下面进入正文:

0 导入PyTorch并设置与设备无关的代码

import torch
from torch import nn

# Note: this notebook requires torch >= 1.10.0
print(torch.__version__)

输出为:

2.7.0+cu118

现在让我们遵循最佳实践并设置与设备无关的代码。

# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

输出为:

cuda

1 获取数据

首先,我们需要一些数据。

和任何优秀的烹饪节目一样,一些数据已经为我们准备好了。

我们从小事做起。

因为我们还没有打算训练最大的模型或使用最大的数据集。

机器学习是一个迭代的过程,从小处开始,让一些东西工作,并在必要时增加。

我们将要使用的数据是Food101 dataset.数据集的一个子集。

Food101是受欢迎的计算机视觉基准,它包含101种不同食物的1000张图像,总计101,000张图像(75,750张训练图像和25,250张测试图像)。

你能想出101种不同的食物吗?

你能想出一个计算机程序来给101种食物分类吗?

我能。

一个机器学习模型!

具体来说,PyTorch计算机视觉模型,就像我们在上一篇中介绍的那样。

我们将从3种食物开始,而不是101种食物:披萨、牛排和寿司。

而不是每个类1000个图像,我们将从随机的10%开始(从小处开始,必要时增加)。

如果您想了解数据的来源,可以查看以下资源:

  • 原始Food101数据集和论文网站Food101 dataset and paper website.。

  • torchvision.datasets.Food101 -我为这篇文章下载的数据版本。

  • data/pizza_steak_sushi.zip -来自Food101的披萨、牛排和寿司图片的压缩档案,使用上面链接的笔记本创建。

让我们编写一些代码来从GitHub下载格式化的数据。

[!TIP]

注意:我们将要使用的数据集已经按照我们想要使用它的目的进行了预格式化。然而,无论您正在处理什么问题,您通常都必须格式化自己的数据集。这是机器学习领域的常规做法。

import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't exist, download it and prepare it... 
if image_path.is_dir():
    print(f"{
     image_path} directory exists.")
else:
    print(f"Did not find {
     image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)
    
    # Download pizza, steak, sushi data
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
        print("Downloading pizza, steak, sushi data...")
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        print("Unzipping pizza, steak, sushi data...") 
        zip_ref.extractall(image_path)

输出为:

Did not find data\pizza_steak_sushi directory, creating one...
Downloading pizza, steak, sushi data...
Unzipping pizza, steak, sushi data...

如果,你已经在系统里面有这些数据,输出就会变成下面这种形式。

data\pizza_steak_sushi directory exists.

2 与数据融为一体(数据准备)

数据集下载!

是时候与它融为一体了。

这是构建模型之前的另一个重要步骤。

什么是检查数据并与之融为一体?

在开始一个项目或构建任何类型的模型之前,了解你正在使用的数据是很重要的。

在我们的例子中,我们有标准图像分类格式的披萨、牛排和寿司图像。

图像分类格式在单独的目录中包含单独的图像类,标题为特定的类名。

例如,pizza的所有图像都包含在pizza/目录中。

这种格式在许多不同的图像分类基准中都很流行,包括ImageNet(最流行的计算机视觉基准数据集)。

您可以在下面看到一个存储格式的示例,图像编号是任意的。

pizza_steak_sushi/ <- overall dataset folder
    train/ <- training images
        pizza/ <- class name as folder name
            image01.jpeg
            image02.jpeg
            ...
        steak/
            image24.jpeg
            image25.jpeg
            ...
        sushi/
            image37.jpeg
            ...
    test/ <- testing images
        pizza/
            image101.jpeg
            image102.jpeg
            ...
        steak/
            image154.jpeg
            image155.jpeg
            ...
        sushi/
            image167.jpeg
            ...

我们的目标是将这个数据存储结构转化为PyTorch可用的数据集。

[!TIP]

注意:所处理的数据结构将根据所处理的问题而有所不同。但前提仍然是:与数据融为一体,然后找到一种最好的方法将其转换为与PyTorch兼容的数据集。

我们可以通过编写一个小的辅助函数来遍历每个子目录并计算存在的文件,从而检查数据目录中的内容。

为此,我们将使用Python内置的os.walk()

import os
def walk_through_dir(dir_path):
  """
  Walks through dir_path returning its contents.
  Args:
    dir_path (str or pathlib.Path): target directory
  
  Returns:
    A print out of:
      number of subdiretories in dir_path
      number of images (files) in each subdirectory
      name of each subdirectory
  """
  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {
     len(dirnames)} directories and {
     len(filenames)} images in '{
     dirpath}'.")
    
walk_through_dir(image_path)

输出为:

data\pizza_steak_sushi directory exists.
There are 2 directories and 0 images in 'data\pizza_steak_sushi'.
There are 3 directories and 0 images in 'data\pizza_steak_sushi\test'.
There are 0 directories and 25 images in 'data\pizza_steak_sushi\test\pizza'.
There are 0 directories and 19 images in 'data\pizza_steak_sushi\test\steak'.
There are 0 directories and 31 images in 'data\pizza_steak_sushi\test\sushi'.
There are 3 directories and 0 images in 'data\pizza_steak_sushi\train'.
There are 0 directories and 78 images in 'data\pizza_steak_sushi\train\pizza'.
There are 0 directories and 75 images in 'data\pizza_steak_sushi\train\steak'.
There are 0 directories and 72 images in 'data\pizza_steak_sushi\train\sushi'.

太好了!

看起来每个训练班大约有75张图片,每个测试班有25张图片。

这应该足够开始了。

请记住,这些图像是原始Food101数据集的子集。

您可以在数据创建笔记本中查看它们是如何创建的。

同时,让我们设置训练和测试路径。

# Setup train and testing paths
train_dir = image_path / "train"
test_dir = image_path / "test"

print(train_dir, test_dir   )

输出为:

data\pizza_steak_sushi directory exists.
data\pizza_steak_sushi\train data\pizza_steak_sushi\test

2.1 形象化

好了,我们已经看到了我们的目录结构是如何格式化的。

现在本着数据探索者的精神,是时候可视化、可视化、可视化了!

让我们写一些代码:

  • 使用pathlib.Path.glob() 获取所有图像路径,以查找所有以.jpg结尾的文件。

  • 使用Python的random.choice() 选择一个随机的图像路径。

  • 使用pathlib.Path.parent.stem获取图像类名。

  • 由于我们正在处理图像,因此我们将使用PIL. image .open() (PIL代表Python image Library)打开随机图像路径。

  • 然后我们将显示图像并打印一些元数据。

import random
from PIL import Image

# Set seed
random.seed(42) # <- try changing this and see what happens

# 1. Get all image paths (* means "any combination")
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. Get random image path
random_image_path = random.choice(image_path_list
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

金沙阳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值