利用继承自Dataset的类,可以访问训练所需的数据
比如一下数据:保存为csv文件
from torch.utils.data import Dataset
import pandas as pd #这个包用来读取CSV数据
class mydataset(Dataset):
def __init__(self,csv_file): #self参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs)
self.csv_data=pd.read_csv(csv_file)
def __len__(self):
return len(self.csv_data)
def __getitem__(self,idx):
data=self.csv_data.values[idx]
return data
data=mydataset('/home/yls/Documents/test.csv')
print(data[3])
print(len(data))
输出结果如下:
[‘OpenSuse’ ‘stable’ ‘OpenSuse Repository’
‘zypper\xa0in\xa0python3-pandas’]
6