PyTorch数据处理:torch.utils.data模块的7个核心函数详解

本文将深入介绍PyTorch中torch.utils.data模块的7个核心函数,这些工具可以帮助你更好地管理和操作数据。

1、Dataset类

Dataset类是PyTorch数据处理的基础。通过继承这个类可以创建自定义的数据集,适应各种类型的数据,如图像、文本或时间序列数据。

要创建自定义数据集,需要实现两个关键方法:

  • __len__方法:返回数据集的大小

  • __getitem__方法:根据给定的索引检索样本

这种灵活性使得Dataset类能够处理各种数据格式和来源。

代码示例:

 import torch
 from torch.utils.data import Dataset
 
 class CustomDataset(Dataset):
     def __init__(self, data, labels):
         self.data = data
         self.labels = labels
     
     def __len__(self):
         return len(self.data)
     
     def __getitem__(self, idx):
         return self.data[idx], self.labels[idx]
 
 # 创建一个简单的数据集
 data = torch.randn(100, 5)  # 100个样本,每个样本5个特征
 labels = torch.randint(0, 2, (100,))  # 二分类标签
 
 dataset = CustomDataset(data, labels)
 print(f"数据集大小: {len(dataset)}")
 print(f"第一个样本: {dataset[0]}")

2、DataLoader

DataLoader是一个极其重要的工具,它封装了数据集并提供了一个可迭代对象。它简化了批量加载、数据shuffling和并行数据处理等操作,是训练和评估模型时高效输入数据的关键。

DataLoader的主要功能包括:

  • 批量加载数据

  • 自动shuffling数据

  • 多进程数据加载以提高效率

  • 自定义数据采样策略

代码示例:

from torch.utils.data import DataLoader
 
 # 使用之前创建的dataset
 batch_size = 16
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
 
 for batch_data, batch_labels in dataloader:
     print(f"批次数据形状: {batch_data.shape}")
     print(f"批次标签形状: {batch_labels.shape}")
     break  # 只打印第一个批次

3. Subset

Subset可以从一个大型数据集中创建一个较小的、特定的子集。这在以下场景中特别有用:

  • 使用数据子集进行实验

  • 将数据集分割为训练集、验证集和测试集

通过指定索引,可以轻松创建所需的数据子集。

from torch.utils.data import Subset
import numpy as np
 
# 创建一个子集,包含原始数据集的前20%的数据
dataset_size = len(dataset)
subset_size = int(0.2 * dataset_size)
subset_indices = np.random.choice(dataset_size, subset_size, replace=False)
 
subset = Subset(dataset, subset_indices)
print(f"子集大小: {len(subset)}")
 
# 使用子集创建新的DataLoader
subset_loader = DataLoader(subset, batch_size=8, shuffle=True)

4、ConcatDataset

ConcatDataset用于将多个数据集组合成一个单一的数据集。当有多个需要一起使用的数据集时,这个工具非常有用。它可以:

  • 合并来自不同来源的数据

  • 创建更大、更多样化的训练集

代码示例:

 from torch.utils.data import ConcatDataset
 
 # 创建两个简单的数据集
 dataset1 = CustomDataset(torch.randn(50, 5), torch.randint(0, 2, (50,)))
 dataset2 = CustomDataset(torch.randn(30, 5), torch.randint(0, 2, (30,)))
 
 # 合并数据集
 combined_dataset = ConcatDataset([dataset1, dataset2])
 print(f"合并后的数据集大小: {len(combined_dataset)}")
 
 # 使用合并后的数据集创建DataLoader
 combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True)

5、TensorDataset

当数据已经以张量形式存在时,TensorDataset非常有用。它将张量包装成一个数据集对象,使得处理预处理的特征和标签变得简单。

TensorDataset的主要优势在于:

  • 直接使用张量数据

  • 简化了已经预处理数据的使用流程

代码示例:

 from torch.utils.data import TensorDataset
 
 # 创建特征和标签张量
 features = torch.randn(1000, 10)  # 1000个样本,每个样本10个特征
 labels = torch.randint(0, 5, (1000,))  # 5分类问题
 
 # 创建TensorDataset
 tensor_dataset = TensorDataset(features, labels)
 
 # 使用TensorDataset创建DataLoader
 tensor_loader = DataLoader(tensor_dataset, batch_size=32, shuffle=True)
 
 for batch_features, batch_labels in tensor_loader:
     print(f"特征形状: {batch_features.shape}, 标签形状: {batch_labels.shape}")
     break

6、RandomSampler

RandomSampler用于从数据集中随机采样元素。在使用随机梯度下降(SGD)等需要随机采样的训练方法时,这个工具尤为重要。它可以帮助:

  • 增加训练的随机性

  • 减少模型过拟合的风险

代码示例:

 from torch.utils.data import RandomSampler
 
 # 使用之前创建的dataset
 random_sampler = RandomSampler(dataset, replacement=True, num_samples=50)
 
 # 使用RandomSampler创建DataLoader
 random_loader = DataLoader(dataset, batch_size=10, sampler=random_sampler)
 
 for batch_data, batch_labels in random_loader:
     print(f"随机采样批次大小: {batch_data.shape[0]}")
     break

7、WeightedRandomSampler

WeightedRandomSampler基于指定的概率(权重)进行有放回采样。这在处理不平衡数据集时特别有用,因为它可以:

  • 更频繁地采样少数类

  • 平衡类别分布,提高模型对少数类的敏感度

代码示例:

 from torch.utils.data import WeightedRandomSampler
 import torch.nn.functional as F
 
 # 假设我们有一个不平衡的数据集
 imbalanced_labels = torch.tensor([0, 0, 0, 0, 1, 1, 2])
 # torch.unique()的功能类似于数学中的集合,就是挑出tensor中的独立不重复元素。
 class_sample_count = torch.tensor([(imbalanced_labels == t).sum() for t in torch.unique(imbalanced_labels, sorted=True)])
 weight = 1. / class_sample_count.float()
 samples_weight = torch.tensor([weight[t] for t in imbalanced_labels])
 
 # 创建WeightedRandomSampler
 weighted_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
 
 # 创建一个简单的数据集
 imbalanced_dataset = TensorDataset(torch.randn(7, 5), imbalanced_labels)
 
 # 使用WeightedRandomSampler创建DataLoader
 weighted_loader = DataLoader(imbalanced_dataset, batch_size=3, sampler=weighted_sampler)
 
 for batch_data, batch_labels in weighted_loader:
     print(f"采样的标签: {batch_labels}")
     break

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ghx3110

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值