如果你是AI零基础,请关注本专栏,将带你一起飞。
在PyTorch中,通常使用torchvision和torch.utils.data模块实现数据集的加载功能。在此模块中提供了用于加载和预处理常见数据集的工具,同时也支持自定义数据集的加载。
2.1.1 PyTorch加载数据集
在PyTorch程序中,模块torchvision.datasets提供了许多常见的预定义数据集,并提供了简单的API来加载这些数据集。以下是一些常用的数据集加载函数:
- torchvision.datasets.ImageFolder:用于加载图像文件夹数据集,其中每个子文件夹表示一个类别,文件夹中的图像属于该类别。
- torchvision.datasets.CIFAR10和torchvision.datasets.CIFAR100:用于加载CIFAR-10和CIFAR-100数据集,这是两个广泛使用的图像分类数据集。
- torchvision.datasets.MNIST:用于加载MNIST手写数字数据集,其中包含了大量的手写数字图像及其对应的标签。
- torchvision.datasets.ImageNet:用于加载ImageNet数据集,这是一个庞大的图像分类数据集,包含数百万个图像和数千个类别。
- torchvision.datasets.VOCDetection:用于加载PASCAL VOC数据集,这是一个常用的目标检测数据集,包含了图像及其对应的物体边界框和类别标签。
上述数据集加载函数通常具有类似的参数,如root(数据集的根目录)、train(是否加载训练集)、download(是否下载数据集)、transform(数据预处理操作)等。此外,还可以使用torch.utils.data.DataLoader函数来创建一个数据加载器,用于批量加载数据。数据加载器可以方便地对数据进行批处理、洗牌、并行加载等操作,以提高数据加载的效率和灵活性。例如下面是一个简单的例子,展示了使用torchvision.datasets和torch.utils.data.DataLoader加载数据集的过程。
实例2-1:加载CIFAR-10数据集
源码路径:daima\2\jia.py
实例文件jia.py的具体实现代码如下所示。
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# 定义数据预处理操作
transform = transforms.Compose([
transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# 创建CIFAR-10数据集实例
train_dataset = CIFAR10(root='data/', train=True, download=True, transform=transform)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 遍历数据加载器
for images, labels in train_loader:
# 在此处进行模型训练或其他操作
pass
在上述代码中,首先定义了一个transform变量,其中包含了一系列预处理操作。然后,使用CIFAR10函数创建一个CIFAR-10数据集实例,指定了数据集的根目录、训练集标志、下载标志和预处理操作。最后,使用DataLoader函数创建一个数据加载器,指定了数据集实例和批量大小等参数。通过这种方式,我们可以方便地加载数据集,并使用数据加载器进行高效的批处理数据加载。
2.1.2 TensorFlow加载数据集
从Tensorflow 2.0开始,提供了专门用于实现数据输入的接口tf.data.Dataset,能够以快速且可扩展的方式加载和预处理数据,帮助开发者高效的实现数据的读入、打乱(shuffle)、增强(augment)等功能。例如在下面的实例文件中,演示了使用tf.data.Dataset加载MNIST 手写数字数据集的的过程。
实例2-2:使用tf.data.Dataset加载MNIST 手写数字数据集(源码路径:daima\2\yu.py)
实例文件tjia.py的具体实现代码如下所示。
import tensorflow as tf
# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# 将数据转换为张量并标准化
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
test_images = test_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
# 将标签转换为独热编码
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
# 创建训练集和测试集的 Dataset 对象
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
train_dataset = train_dataset.shuffle(buffer_size=60000).batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(batch_size)
# 打印数据集信息
print("训练集样本数:", len(train_images))
print("测试集样本数:", len(test_images))
print("图像形状:", train_images.shape[1:])
print("标签类别数:", train_labels.shape[1])
上述代码加载了 MNIST 数据集,将图像数据转换为张量并进行了标准化。然后,创建了训练集和测试集的 Dataset 对象,并显示了一些有关数据集的基本信息。请注意,这仅仅是加载数据集的代码,不涉及模型构建和训练。执行后会输出:
训练集样本数: 60000
测试集样本数: 10000
图像形状: (28, 28, 1)