- 自定义一个数据集
#定义一个数据集
class BulldozerDataset(Dataset):
“”" 数据集演示 “”"
def init(self, csv_file):
“”“实现初始化方法,在初始化的时候将数据读载入”“”
self.df=pd.read_csv(csv_file)
def len(self):
‘’’
返回df的长度
‘’’
return len(self.df)
def getitem(self, idx):
‘’’
根据 idx 返回一行数据
‘’’
return self.df.iloc[idx].SalePrice
- 至此,我们的数据集已经定义完成了,我们可以实例化一个对象来访问
ds_demo= BulldozerDataset(‘median_benchmark.csv’)
- 我们可以直接使用如下命令查看数据集数据
前面我们已经实现了__len__方法,所以可以直接使用
len(ds_demo)
- 使用索引可以直接访问对应的数据
ds_demo[0]
自定义的数据集已经创建好了,下面我们使用官方提供的数据载入器,读取数据
1.2 DataLoader
DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、shuffle(是否进行shuffle操作)、num_workers(加载数据时使用几个子进程)。下面做一个简单的演示:
dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)
DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据
idata=iter(dl)
print(next(idata))
常见的用法是使用for循环对其进行遍历
for i, data in enumerate(dl):
print(i,data)
为了节约空间,这里只循环一遍
break
至此,我们已经可以通过dataset定义数据集,并使用D