(1)加载数据集

如果你是AI零基础,请关注本专栏,将带你一起飞。

在PyTorch中,通常使用torchvision和torch.utils.data模块实现数据集的加载功能。在此模块中提供了用于加载和预处理常见数据集的工具,同时也支持自定义数据集的加载。

2.1.1  PyTorch加载数据集

在PyTorch程序中,模块torchvision.datasets提供了许多常见的预定义数据集,并提供了简单的API来加载这些数据集。以下是一些常用的数据集加载函数:

  1. torchvision.datasets.ImageFolder:用于加载图像文件夹数据集,其中每个子文件夹表示一个类别,文件夹中的图像属于该类别。
  2. torchvision.datasets.CIFAR10和torchvision.datasets.CIFAR100:用于加载CIFAR-10和CIFAR-100数据集,这是两个广泛使用的图像分类数据集。
  3. torchvision.datasets.MNIST:用于加载MNIST手写数字数据集,其中包含了大量的手写数字图像及其对应的标签。
  4. torchvision.datasets.ImageNet:用于加载ImageNet数据集,这是一个庞大的图像分类数据集,包含数百万个图像和数千个类别。
  5. torchvision.datasets.VOCDetection:用于加载PASCAL VOC数据集,这是一个常用的目标检测数据集,包含了图像及其对应的物体边界框和类别标签。

上述数据集加载函数通常具有类似的参数,如root(数据集的根目录)、train(是否加载训练集)、download(是否下载数据集)、transform(数据预处理操作)等。此外,还可以使用torch.utils.data.DataLoader函数来创建一个数据加载器,用于批量加载数据。数据加载器可以方便地对数据进行批处理、洗牌、并行加载等操作,以提高数据加载的效率和灵活性。例如下面是一个简单的例子,展示了使用torchvision.datasets和torch.utils.data.DataLoader加载数据集的过程。

实例2-1:加载CIFAR-10数据集

源码路径:daima\2\jia.py

实例文件jia.py的具体实现代码如下所示。

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# 定义数据预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),

])

# 创建CIFAR-10数据集实例

train_dataset = CIFAR10(root='data/', train=True, download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 遍历数据加载器
for images, labels in train_loader:
    # 在此处进行模型训练或其他操作
    pass

在上述代码中,首先定义了一个transform变量,其中包含了一系列预处理操作。然后,使用CIFAR10函数创建一个CIFAR-10数据集实例,指定了数据集的根目录、训练集标志、下载标志和预处理操作。最后,使用DataLoader函数创建一个数据加载器,指定了数据集实例和批量大小等参数。通过这种方式,我们可以方便地加载数据集,并使用数据加载器进行高效的批处理数据加载。

2.1.2  TensorFlow加载数据集

Tensorflow 2.0开始,提供了专门用于实现数据输入的接口tf.data.Dataset,能够以快速且可扩展的方式加载和预处理数据,帮助开发者高效的实现数据的读入、打乱(shuffle)、增强(augment)等功能。例如在下面的实例文件中,演示了使用tf.data.Dataset加载MNIST 手写数字数据集的的过程。

实例2-2:使用tf.data.Dataset加载MNIST 手写数字数据集(源码路径:daima\2\yu.py

实例文件tjia.py的具体实现代码如下所示。

import tensorflow as tf

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# 将数据转换为张量并标准化
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
test_images = test_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
# 将标签转换为独热编码
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
# 创建训练集和测试集的 Dataset 对象
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
train_dataset = train_dataset.shuffle(buffer_size=60000).batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(batch_size)
# 打印数据集信息
print("训练集样本数:", len(train_images))
print("测试集样本数:", len(test_images))
print("图像形状:", train_images.shape[1:])
print("标签类别数:", train_labels.shape[1])

上述代码加载了 MNIST 数据集,将图像数据转换为张量并进行了标准化。然后,创建了训练集和测试集的 Dataset 对象,并显示了一些有关数据集的基本信息。请注意,这仅仅是加载数据集的代码,不涉及模型构建和训练。执行后会输出:

训练集样本数: 60000

测试集样本数: 10000

图像形状: (28, 28, 1)

  • 22
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农三叔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值