PyTorch数据集加载相关类介绍

torch.utils.data

torch.utils.data

  • torch.utils.data.Dataset
  • torch.utils.data.TensorDataset

一. 数据集相关类介绍

  1. class torch.utils.data.Dataset
    这是 Dataset 的抽象类。所有其它的数据集都应该是该类的子类。所有的子类应该实现__len__和__getitem__方法,提供数据集的大小和整数索引,范围从0到len(self)。
  2. class torch.utils.data.TensorDataset(data_tensor, target_tensor)
    torch.utils.data.Dataset 的子类
    包装数据和目标张量的数据集。
    通过沿着第一个维度索引两个张量来恢复每个样本。
    参数
    * data_tensor(Tensor) - 包含样本数据。
    * target_tensor(Tensor) - 包含样本目标(标签)
from torch.utils import data

# 合成数据和标签
X = torch.rand(20, 10)
y = X @ torch.rand(10, 1)

# 1. 使用 data.TensorDataset 类来实例化数据集
train_data1 = data.TensorDataset(X, y)

>>> train1[0]
>(tensor([0.3308, 0.5128, 0.7089, 0.5980, 0.5363, 0.4596, 0.7110, 0.9807, 0.8420,
         0.2185]),
 tensor([3.1887]))
# 2. 自定义数据集
class Mydata(data.Dataset):
	"""
	继承data.Dataset类来自定义数据集类,需要实现__len__和__getitem__方法。
	"""
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return (self.data[index], self.label[index])
train_data2 = Mydata(X, y)

>>>train_data2[0]
>(tensor([0.3308, 0.5128, 0.7089, 0.5980, 0.5363, 0.4596, 0.7110, 0.9807, 0.8420,
         0.2185]),
 tensor([3.1887]))

二. 数据采样器相关类

  1. class torch.utils.data.sampler.Sampler(data_source)
    所有采样器的基础类。
    每个采样器子类必须提供一个__iter__方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的__len__方法。
  2. class torch.utils.data.sampler.SequentialSampler(data_source)
    样本元素顺序排列,始终以相同的顺序。
    参数
    * data_source(Dataset) - 采样的数据集。
  3. class torch.utils.data.sampler.RandomSampler(data_source, replacement=True, num_samples=[int])
    样本元素随机,没有替换。
    参数
    * data_source(Dataset) - 采样的数据集。
    * replacement - 是否放回抽样
    * num_samples - 抽样数量
  4. class torch.utils.data.sampler.SubsetRandomSampler(indices)
    样本元素从指定的索引列表中随机抽取,没有替换。
    参数
    * indices (list) – 索引的列表
  5. class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)
    样本元素来自于[0,…,len(weights)-1],给定概率(weights)。
    参数
    * weights (list) – 权重列表。没必要加起来为1 - num_samples (int) – 抽样数量
  6. class torch.utils.data.sampler.BatchSampler(Sampler[List[int]])
    以另一个sampler为参数,成批返回索引。
    参数
    * sampler - 采样器。
    * batch_size - 批次大小。
    * drop_last - 是否丢弃最后的小批次。
from torch.utils import data

# 合成数据和标签
X = torch.rand(20, 10)
y = X @ torch.rand(10, 1)

# 使用 data.TensorDataset 类来实例化数据集
train_data = data.TensorDataset(X, y)

# 1. class SequentialSampler(Sampler[int])类
sequential_sampler = SequentialSampler(train_data)

>>>for e in sequential_sampler:
>>>    print(e, end=" ")
>0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19

# 2. class RandomSampler(Sampler[int])类
random_sampler = RandomSampler(train_data,replacement=True,num_samples=10)
>>>for e in random_sampler:
>>>    print(e, end=" ")
>12 9 6 7 17 1 0 9 1 4 

# 3. class SubsetRandomSampler(Sampler[int])类
subset_random_sampler = SubsetRandomSampler([x for x in range(10)])
>>>for e in subset_random_sampler :
>>>    print(e, end=" ")
>3 8 7 9 1 4 5 6 0 2 

# 4. class WeightedRandomSampler(Sampler[int])类
weighted_random_sampler = WeightedRandomSampler([50, 10, 5, 1], 10, replacement=True)
>>>for e in weighted_random_sampler:
>>>    print(e, end=" ")
>0 0 0 1 0 0 1 0 2 1 

# 5. class BatchSampler(Sampler[List[int]])类
batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
>>>print(list(batch_sampler))
>[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

三. 数据加载器相关类

  1. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_works=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=Fasle)
    数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
    参数
    * dataset(Dataset) - 加载数据的数据集。
    * batch_size(int, optional) - 每个batch加载多少个样本(默认:1)。
    * shuffle(bool, optional) - 设置为True时会在每个epoch重新打乱数据(默认为:False)。
    * sampler(Sampler, optional) - 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
    * num_workers(int, optional) - 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)。
    * drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)。
from torch.utils import data

# 合成数据和标签
X = torch.rand(20, 10)
y = X @ torch.rand(10, 1)

# 使用 data.TensorDataset 类来实例化数据集
train_data = data.TensorDataset(X, y)

# 1. class SequentialSampler(Sampler[int])类
sequential_sampler = SequentialSampler(train_data)

# 定义下载器
loader = data.Dataloader(train_data, sampler=sequential_sampler, batch_size=5)
>>>for e in loader:
>>>    print(e, end" ")
>[tensor([[0.2643, 0.9046, 0.4118, 0.1491, 0.8853, 0.3904, 0.2309, 0.5114, 0.5720,
         0.6731],
        [0.3611, 0.5206, 0.5185, 0.9278, 0.2859, 0.7677, 0.6613, 0.7987, 0.0447,
         0.2355],
        [0.8693, 0.2007, 0.5954, 0.2602, 0.7560, 0.8815, 0.7273, 0.9077, 0.1683,
         0.2609],
        [0.2643, 0.9046, 0.4118, 0.1491, 0.8853, 0.3904, 0.2309, 0.5114, 0.5720,
         0.6731],
        [0.2638, 0.7770, 0.3765, 0.0899, 0.2760, 0.2596, 0.9405, 0.5447, 0.0263,
         0.5279]]), tensor([[3.0099],
        [2.5908],
        [3.3519],
        [3.0099],
        [2.0028]])] 
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值