pytorch的Dataset和Dataloader的简单使用

导入方法

from torch.utils.data import Dataset, DataLoader

作用

Dataset类:可以根据id索引出单个的数据,还可以进行一些预处理。
Dataloader类:将数据集进行打包成迭代器。

基础使用方法

Dataset:需要实现三个函数__init__,len,getitem,可以自己添加一些预处理的函数,比如划分训练集、验证集。

from torch.utils.data import Dataset, DataLoader, random_split

class MyDataset(Dataset)
	def __init__(self, data):#如果传的是文件地址也可以,比如cvs文件
		#df = read_csv(path);  
		#对csv文件进行处理
		self.data = data
	def __len__(self):    ##获取数据集大小
		return len(self.data)
	def __getitem__(self, id)
		return self.data[id]
    def get_splits(self, n_test=0.3):
       # determine sizes
       test_size = round(n_test * len(self.X))
       train_size = len(self.X) - test_size
       # calculate the split
       return random_split(self.data, [train_size, test_size])
		

Dataloader类:将数据集按batch进行打包

常见的参数说明:
dataset:传入训练集或者验证集(Dataset对象)。
batch_size:多少个数据组成一个整体,越大对内存要求更高。
shuffle:训练集一般设为True,验证集为False。
其余很多参数不常用,比如设置采样的规则、数据集不够分时最后一个batch丢不丢等等。
所以一般都是这样写:

train, test = dataset.get_splits()
train_dl = DataLoader(train, batch_size=32, shuffle=True)
test_dl = DataLoader(test, batch_size=32, shuffle=False)
  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值