Pytorch Dataset类
数据集的类
- torch.utils.data.Dataset
- 实现两个方法
- _getitem_(self,index)
- 获取索引对应位置的一条数据
- _len_(self)
- 返回数据的总数量
import torch
import torch.utils.data as Data
class myDataSet(Data.Dataset):
def __init__(self):
self.f=open('data','r').readlines()
def __getitem__(self, item):
return self.f[item]
def __len__(self):
return len(self.f)
dataset=myDataSet()
print('len = ',len(dataset))
print('dataset[1] = ',dataset[0])
print('dataset[2] = ',dataset[1])
len = 2
dataset[1] = 1 2 3 4 5 6 7 8 9 10
dataset[2] = 3.5 4.7 7.7 8.3 10.9 14 14.8 17.2 18.9 21.3
数据加载器类
- 批处理数据(batch)
- 打乱数据(shuffle=True)
- 使用多线程相乘并行加载数据(num_workers)
- 删除mod(batch)多余的元素(drop_last=True)
DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)
enumerate()返回遍历的序号
data_loader=Data.DataLoader(dataset=dataset,batch_size=1,shuffle=True)
print('data_loader = ',data_loader)
print('len(data_loader) = ',len(data_loader))
for index,i in enumerate(data_loader):
print('这是第%d个元素'%index)
print('i = ',i)
print('i[0].strip() = ',i[0].strip())
data_loader = <torch.utils.data.dataloader.DataLoader object at 0x0000022132F02E48>
len(data_loader) = 2
这是第0个元素
i = ['1 2 3 4 5 6 7 8 9 10\n']
i[0].strip() = 1 2 3 4 5 6 7 8 9 10
这是第1个元素
i = ['3.5 4.7 7.7 8.3 10.9 14 14.8 17.2 18.9 21.3']
i[0].strip() = 3.5 4.7 7.7 8.3 10.9 14 14.8 17.2 18.9 21.3
data_loader也支持len方法
- 向上取整(math.ceil())
附录:
data(文件)
1 2 3 4 5 6 7 8 9 10
3.5 4.7 7.7 8.3 10.9 14 14.8 17.2 18.9 21.3