PyTorch笔记(2)

本文详细介绍了PyTorch中Dataset和DataLoader类在处理和加载数据集,特别是神经网络训练过程中的重要性。Dataset用于定义数据访问和预处理,而DataLoader提供了多线程加载、批处理和数据打乱功能。此外,还涵盖了TensorBoard的使用,展示了如何记录和可视化训练过程中的指标。
摘要由CSDN通过智能技术生成

加载数据

在PyTorch中,'Dataset'和'DataLoader'是两个关键的类,用于处理和加载数据,特别是在训练神经网络时。

1.Dataset数据集

Dataset是一个抽象类,用于表示一个数据集。它允许你定义如何访问数据以及在数据集上进行预处理。为了使用Dataset,你需要继承该类并实现两个主要方法:

a.__len__(self):返回数据集的总样本数。

b.__getitem__(self, index): 根据给定的索引返回对应位置的数据样本。

以下是一个简单示例:

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # 进行必要的预处理
        # 返回单个数据样本
        return sample

 然后可以创建一个CustomDataset的实例,并传递给DataLoader以便进行批量加载和处理。

并且,在Dataset__getitem__方法中,你可以包含与数据相关联的标签,并在返回的元组中包括这些标签。这样,当使用DataLoader加载数据时,你可以同时获取输入数据和对应的标签。

以下是一个简单的示例,假设你的数据集包含图像和相应的类别标签:

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        # 进行必要的预处理
        # 返回包含数据样本和标签的元组
        return sample, label

然后就可以像之前一样使用DataLoader来加载数据:

from torch.utils.data import DataLoader

# 假设有图像数据data和对应的标签labels
custom_dataset = CustomDataset(data, labels)

# 创建一个DataLoader,将CustomDataset传递给它
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)

在迭代data_loader时,每个批次将包含图像数据和对应的标签,可以通过解包批次来分别访问输入数据和标签:

for batch_data, batch_labels in data_loader:
    # 在这里处理批量的数据和标签
    # batch_data 包含图像数据
    # batch_labels 包含对应的标签
    pass

2.DataLoader

 DataLoader是一个用于加载数据的工具类。它封装了一个Dataset实例,并提供多线程数据加载、数据打乱(shuffle)、批量大小(batch size)等功能。通过使用DataLoader,可以方便的迭代整个数据集,同时充分利用硬件资源加速数据加载。

下面是一个简单的使用示例:

from torch.utils.data import DataLoader

# 创建一个CustomDataset的实例
custom_dataset = CustomDataset(data)

# 创建一个DataLoader,将CustomDataset传递给它
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)

 在这个例子中,batch_size表示每个批次的样本数量,shuffle=True表示在每个epoch之前对数据进行打乱。可以通过迭代data_loader来获取批量的训练数据。

并且,DataLoader允许定义自定义的数据转换函数,从而在每个批次加载时对数据进行不同的形式或预处理。这通常在创建DataLoader实例时使用collate_fn参数来完成。

collate_fn是一个用户自定义的函数,用于处理单个批次的数据。它接受一个包含样本的列表,并返回一个批次的张量或其他数据结构。通过自己定义的collate_fn,可以对每个批次的数据进行个性化的转换。

下面是一个简单的示例,假设希望对图像进行归一化,并将标签转换为张量:

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

class CustomDataset(Dataset):
    # 假设 CustomDataset 中包含图像数据和对应的标签

# 定义数据预处理的转换
data_transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像数据转为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 创建 CustomDataset 实例,传入数据和标签
custom_dataset = CustomDataset(data, labels)

# 创建 DataLoader,并指定 collate_fn 为自定义的转换函数
data_loader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True, collate_fn=custom_collate_fn)

# 自定义的 collate_fn 函数
def custom_collate_fn(batch):
    # batch 包含一个批次的样本,每个样本是 (sample, label) 的元组

    # 将样本和标签分别提取出来
    samples, labels = zip(*batch)

    # 对图像数据进行转换
    transformed_samples = [data_transform(sample) for sample in samples]

    # 将标签转换为张量
    tensor_labels = torch.tensor(labels)

    # 返回包含转换后数据的元组
    return transformed_samples, tensor_labels

在这个例子中,custom_collate_fn函数接收一个包含样本的列表,其中每个样本是一个元组’(sample,label)‘。它对图像进行了转换,将图像数据转换为张量并行归一化,将标签转换为张量。最后,它返回一个包含转换后数据的元组'(transformed_samples,tensor_labels)'。

补充:

Tensorboard的使用

PyTorch的内置TensorBoard支持是通过torch.utils.tensorboard模块提供的。其基本步骤为:

1.导入库:

from torch.utils.tensorboard import SummaryWriter

2.创建SummaryWriter

writer = SummaryWriter('logs')  # 'logs'是日志文件保存的目录

在这里logs是一个目录路径,用于存储TensorBoard的日志文件。

3.在训练过程中记录数据

在训练循环中,你可以使用add_scalar方法记录标量数据,例如损失和准确率。还可以使用add_histogram方法记录模型的权重和梯度直方图,使用add_image方法记录图像等。 

for epoch in range(num_epochs):
    # 训练代码...

    # 在每个epoch结束时记录损失
    writer.add_scalar('Loss', loss.item(), global_step=global_step)

    # 在每个epoch结束时记录模型的权重直方图
    for name, param in model.named_parameters():
        writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step=global_step)

    global_step += 1

4.启动 TensorBoard

在终端输入以下指令:

tensorboard --logdir=logs

logdir=日志文件所在文件夹的名称 

 这将在本地主机的6006端口启动TensorBoard服务器(若想在其他端口启动,在终端输入:

tensorboard --logdir=logs --port=端口名

 即可)。然后,可以通过浏览器访问http://localhost:6006来查看可视化结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值