一、综述
Dataset :对数据进行抽象,将数据包装为Dataset类。
DataLoader:在 Dataset之上对数据进行进一步处理,包括进行乱序处理,获取一个batch_size的数据等。
二、Dataset
在Dataset类中必须重新 getitem()、len()两个方法。
- 创建数据
ss=np.linspace(1,100,100)
np.savetxt("sample_data.txt", ss.reshape(-1,4))
数据格式如下所示:
2. 创建自定义Dataset
import numpy as np
import torch as t
from torch.utils.data import Dataset
class MyDataSet(Dataset):
def __init__(self):
#使用numy读取数据
txt_data = np.loadtxt('sample_data.txt')
#取数据前三列为x
self._x = t.from_numpy(txt_data[:,:3])
#取数据最后一列为target值
self._y = t.from_numpy(txt_data[:,-1])
#获取数据的长度
self._len = len(txt_data)
def __getitem__(self,item):
#item对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
return self._x[item],self._y[item]
def __len__(self):
#带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回
return self._len
dataset = MyDataSet()
print(len(dataset))
data =next(iter(dataset))
print(data)
三、 DataLoader
关键参数:
- dataset :数据集
- batch_size : 一个批次的大小
- shuffle : 是否乱序处理
- sampler:非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.
- drop_last:如果数据集大小不能整除batch_size的话,是否删除最后一个batch
from torch.utils.data import DataLoader
data = MyDataSet()
dataloader = DataLoader(data,batch_size=4,shuffle=True,drop_last=True,num_workers=0)
for i,data in enumerate(dataloader):
print('batch---->',i+1)
inputs,labels=data
print(inputs)
print(labels)
print("*"*30)
四、random_split
pytorch中 random_split类似于 sklearn中的train_test_split类似的功能,将数据切分为训练集、测试集、验证集。
from torch.utils.data import random_split
all_length =len(dataset)
train_size =int(0.8*all_length)
test_size = all_length - train_size
#切分数据集
train_dataset,test_dataset = random_split(dataset,[train_size,test_size])
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)
for i,curr_data in enumerate(train_loader):
print('batch---->',i+1)
inputs,labels=curr_data
print(inputs)
print(labels)
print("*"*30)
```
![在这里插入图片描述](https://img-blog.csdnimg.cn/2021012612065338.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0dhb3dhaGFoYQ==,size_16,color_FFFFFF,t_70)