加载数据
在PyTorch中,'Dataset'和'DataLoader'
是两个关键的类,用于处理和加载数据,特别是在训练神经网络时。
1.Dataset数据集
Dataset是一个抽象类,用于表示一个数据集。它允许你定义如何访问数据以及在数据集上进行预处理。为了使用Dataset,你需要继承该类并实现两个主要方法:
a.__len__(self):返回数据集的总样本数。
b.__getitem__(self, index)
: 根据给定的索引返回对应位置的数据样本。
以下是一个简单示例:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# 进行必要的预处理
# 返回单个数据样本
return sample
然后可以创建一个CustomDataset的实例,并传递给DataLoader以便进行批量加载和处理。
并且,在Dataset
的__getitem__
方法中,你可以包含与数据相关联的标签,并在返回的元组中包括这些标签。这样,当使用DataLoader
加载数据时,你可以同时获取输入数据和对应的标签。
以下是一个简单的示例,假设你的数据集包含图像和相应的类别标签:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
label = self.labels[index]
# 进行必要的预处理
# 返回包含数据样本和标签的元组
return sample, label
然后就可以像之前一样使用DataLoader来加载数据:
from torch.utils.data import DataLoader
# 假设有图像数据data和对应的标签labels
custom_dataset = CustomDataset(data, labels)
# 创建一个DataLoader,将CustomDataset传递给它
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)
在迭代data_loader时,每个批次将包含图像数据和对应的标签,可以通过解包批次来分别访问输入数据和标签:
for batch_data, batch_labels in data_loader:
# 在这里处理批量的数据和标签
# batch_data 包含图像数据
# batch_labels 包含对应的标签
pass
2.DataLoader
DataLoader是一个用于加载数据的工具类。它封装了一个Dataset实例,并提供多线程数据加载、数据打乱(shuffle)、批量大小(batch size)等功能。通过使用DataLoader,可以方便的迭代整个数据集,同时充分利用硬件资源加速数据加载。
下面是一个简单的使用示例:
from torch.utils.data import DataLoader
# 创建一个CustomDataset的实例
custom_dataset = CustomDataset(data)
# 创建一个DataLoader,将CustomDataset传递给它
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)
在这个例子中,batch_size表示每个批次的样本数量,shuffle=True表示在每个epoch之前对数据进行打乱。可以通过迭代data_loader来获取批量的训练数据。
并且,DataLoader允许定义自定义的数据转换函数,从而在每个批次加载时对数据进行不同的形式或预处理。这通常在创建DataLoader实例时使用collate_fn参数来完成。
collate_fn是一个用户自定义的函数,用于处理单个批次的数据。它接受一个包含样本的列表,并返回一个批次的张量或其他数据结构。通过自己定义的collate_fn,可以对每个批次的数据进行个性化的转换。
下面是一个简单的示例,假设希望对图像进行归一化,并将标签转换为张量:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
class CustomDataset(Dataset):
# 假设 CustomDataset 中包含图像数据和对应的标签
# 定义数据预处理的转换
data_transform = transforms.Compose([
transforms.ToTensor(), # 将图像数据转为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
# 创建 CustomDataset 实例,传入数据和标签
custom_dataset = CustomDataset(data, labels)
# 创建 DataLoader,并指定 collate_fn 为自定义的转换函数
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True, collate_fn=custom_collate_fn)
# 自定义的 collate_fn 函数
def custom_collate_fn(batch):
# batch 包含一个批次的样本,每个样本是 (sample, label) 的元组
# 将样本和标签分别提取出来
samples, labels = zip(*batch)
# 对图像数据进行转换
transformed_samples = [data_transform(sample) for sample in samples]
# 将标签转换为张量
tensor_labels = torch.tensor(labels)
# 返回包含转换后数据的元组
return transformed_samples, tensor_labels
在这个例子中,custom_collate_fn函数接收一个包含样本的列表,其中每个样本是一个元组’(sample,label)‘。它对图像进行了转换,将图像数据转换为张量并行归一化,将标签转换为张量。最后,它返回一个包含转换后数据的元组'(transformed_samples,tensor_labels)'。
补充:
Tensorboard的使用
PyTorch的内置TensorBoard支持是通过torch.utils.tensorboard模块提供的。其基本步骤为:
1.导入库:
from torch.utils.tensorboard import SummaryWriter
2.创建SummaryWriter
writer = SummaryWriter('logs') # 'logs'是日志文件保存的目录
在这里logs是一个目录路径,用于存储TensorBoard的日志文件。
3.在训练过程中记录数据
在训练循环中,你可以使用add_scalar方法记录标量数据,例如损失和准确率。还可以使用add_histogram方法记录模型的权重和梯度直方图,使用add_image方法记录图像等。
for epoch in range(num_epochs):
# 训练代码...
# 在每个epoch结束时记录损失
writer.add_scalar('Loss', loss.item(), global_step=global_step)
# 在每个epoch结束时记录模型的权重直方图
for name, param in model.named_parameters():
writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step=global_step)
global_step += 1
4.启动 TensorBoard
在终端输入以下指令:
tensorboard --logdir=logs
logdir=日志文件所在文件夹的名称
这将在本地主机的6006端口启动TensorBoard服务器(若想在其他端口启动,在终端输入:
tensorboard --logdir=logs --port=端口名
即可)。然后,可以通过浏览器访问http://localhost:6006来查看可视化结果。