之前用过sklearn提供的划分数据集的函数,觉得超级方便。但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是“pytorch split dataset”之类的,但是搜出来还是没有我想要的。结果今天见鬼了突然看见了这么一个函数torch.utils.data.Subset。我的天,为什么超级开心hhhh。终于不用每次都手动划分数据集了。
torch.utils.data
Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler
torch的这个文件包含了一些关于数据集处理的类:
- class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
- class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
- cl