pytorch中Dataset,TensorDataset和DataLoader用法

1 用法介绍

pytorch中常用类torch.utils.data.Datasettorch.utils.data.TensorDataset对数据进行封装;常用类torch.utils.data.DataLoader对数据进行加载。具体的用法细节如下所示:

1.1 torch.utils.data.Dataset的用法

class Dataset(object):
	def  __getitem__(self, index):
		raise NotImplementError
	def __len__(self):
		raise NotImplementError
	def __add__(self, other):
		return ConcatDataset([self, other])

注:torch.utils.data.Dataset表示一个数据集的抽象类,所有的其它数据集都要以它为父类进行数据封装。Dataset的类函数__getitem____len__必须要被进行重写。

1.2 torch.utils.data.TensorDataset的用法

classtorch.utils.data.TensorDataset(data_tensor, target_tensor)

  • data_tensor : 需要被封装的数据样本
  • target_tensor : 需要被封装的数据标签
class TensorDataset(Dataset):
    # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]
    def __len__(self):
        return self.data_tensor.size(0)

注:torch.utils.data.TensorDataset继承父类torch.utils.data.Dataset,不需要对类TensorDataset的函数进行重写。

1.3 torch.utils.data.DataLoader的用法

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

  • dataset (Dataset): 封装后的数据集。
  • batch_size (python:int,optional)): 每一批加载的样本量,默认值为1。
  • shuffle (bool,optional): 设置为True时,每一个epoch重新打乱数据顺序。
  • sampler (Sampler,optional): 定义在数据集中进行采样的策略,如果被指定,则False必须为shuffle
  • batch_sampler (Sampler,optional): 类似sampler,但是一次返回一批索引。互斥有batch_sizeshufflesamplerdrop_last
  • num_workers (python:int,optional): 多少个子进程用于数据加载。0表示将在主进程中加载数据,默认值为0。
  • collate_fn(callable,optional): 合并样本列表以形成张量的小批量。在从地图样式数据集中使用批量加载时使用。
  • pin_memory (bool,optional): 如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。
  • drop_last (bool,optional): 设置为True,如果数据集大小不能被该批次大小整除则删除最后一个不完整的批次。如果False,数据集的大小不能被批量大小整除,那么最后一个批量将更小,默认值为False
  • timeout (numeric,optional): 如果为正,则为从worker收集批次的超时值。应始终为非负数,默认值为0。
  • worker_init_fn (callable,optional): 如果不是None,则在种子工作之后和数据加载之前,将在每个工作程序子进程上调用此程序,并以工作程序ID作为输入,取值为[0, num_workers - 1]None。

注:torch.utils.data.DataLoader结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时该类可以将数据进行切分,每次抛出一组数据,直至把所有的数据都抛出。

2 代码实例

实例1中数据封装利用的是TensorDataset,数据加载利用的是DataLoader具体代码如下所示:

import torch
import torch.utils.data as Data
BATCH_SIZE = 5 
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
print(loader)
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

if __name__ == '__main__':
    show_batch()

运行结果为:

实例2中利用Dataset抽象类自定义出了一个子类进行数据封装,数据加载利用的是DataLoader具体代码如下所示:

import numpy as np 
import torch
import torch.utils.data as Data

def generate_dataset(sample_num, class_num, X_shape):
	Label_list = []
	Sample_list = []
	for i in range(sample_num):
		y = np.random.randint(0, class_num)
		Label_list.append(y)
		Sample_list.append(np.random.normal(y, 0.2, X_shape))
	return torch.tensor(Sample_list), torch.tensor(Label_list)

class Normal_Dataset(Data.Dataset):
	def __init__(self, Numpy_Dataset):
		super(Normal_Dataset, self).__init__()
		self.data_tensor = Numpy_Dataset[0]
		self.target_tensor = Numpy_Dataset[1]

	def __getitem__(self, index):
		return self.data_tensor[index], self.target_tensor[index]

	def __len__(self):
		return self.data_tensor.size(0)

if __name__ == '__main__':
	numpy_dataset = generate_dataset(10, 2, 5) 
	Dataset = Normal_Dataset(numpy_dataset)
	DataLoader = Data.DataLoader(
						dataset = Dataset,
						batch_size = 2,
						shuffle = True,
						num_workers = 0,
						)
	for epoch in range(2):
		for step, (batch_x, batch_y) in enumerate(DataLoader):
			print("step: {}, batch_x: {}, batch_y: {}".format(step, batch_x, batch_y))

运行结果为:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值