pytorch之自定义数据集
pytorch提供了torch.utils.data.Dataset这一抽象类来定义自己的数据集。使用时需要定义__len__ 和__getitem__这两个函数。
例如,定义一个最简单的数据集
class myDataset(Dataset):
def __init__(self, file_path):
self.csv_data = pandas.read_csv(file_path)
def __len__(self):
return len(self.csv_data)
def __getitem__(self, item):
return self.csv_data.loc[item]
另外,我们可以使用torch.utils.data.Dataloader来定义一个迭代器,从而实现取batch,shuffle和多线程。