两大指南函数
-
dir()
功能:返回包含查询对象的所有属性和方法名称的列表
dir(torch) # 输出: torch中所包含的全部方法 dir(torch.cuda.is_available) # 输出:该函数所有属性
-
help()
功能:查看函数或模块用途的详细说明
使用dir()和help()查询函数仅输入函数名,不输入括号。
数据加载和预处理
-
Dataset类
Dataset是一个抽象类,它的主要作用是封装数据集。你可以将Dataset看作是数据的集合,
定义了如何获取数据集中的单个样本。Dataset类需要实现__init__,__len__和__getitem__三个方法。
一些内置的Dataset类,比如
ImageFolder
、CIFAR10
等,它们已经实现了这些方法,可以直接使用。 -
DataLoader类
DataLoader是Dataset的包装器,它的主要作用是提供批量加载数据的能力。
使用DataLoader时,需要指定以下参数:
- dataset:你创建的Dataset实例。
- batch_size:每个批次的样本数量。
- shuffle:是否在每个epoch开始时打乱数据。
- num_workers:加载数据时使用的进程数量。
from torch.utils.data import DataLoader train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
iter函数将train_dataloader变成一个迭代器,使用next函数可以以此从迭代器中生成一个一个的批次
train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0]
可视化中squeeze()的作用
squeeze函数用于去除张量中所有大小为1的维度
- 简化可视化:在可视化图像时,我们通常只需要[channels, height, width]的三维张量。squeeze函数可以帮助去除批次维度(batch_size),使得我们可以直接处理单个图像。
- 去除冗余维度:如果图像数据在加载或预处理过程中被错误地增加了额外的维度,squeeze可以帮助去除这些不必要的维度。