通过继承DataSet抽象类定义自己的数据集,使用DataLoader将数据集变为一个可迭代对象
1. 继承DataSet,自定义一个数据集(:
需要继承DataSet类,并且实现两个成员方法:
- getitem_() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本
- len_() 该方法返回数据集的总长度
eg: 实例化一个对象ds_demo,通过ds_demo[index]方法得到index对应的数据值,通过len(ds_demo)获取数据总长度
#引用
from torch.utils.data import Dataset
import pandas as pd
#定义一个数据集
class BulldozerDataset(Dataset):
""" 数据集演示 """
def __init__(self, csv_file):
"""实现初始化方法,在初始化的时候将数据读载入"""
self.df=pd.read_csv(csv_file)
def __len__(self):
'''
返回df的长度
'''
return len(self