# This section shows how to use a DataLoader with an existing dataset.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class DiabetesDataset(Dataset):
    """Map-style dataset backed by a comma-delimited numeric text file.

    The whole file is read into memory once, at construction time. Every
    column except the last is treated as a feature; the last column is
    treated as the target.

    Attributes:
        len: Number of rows (samples) in the file.
        x_data: float32 tensor of shape (rows, columns - 1) with the features.
        y_data: float32 tensor of shape (rows, 1) with the targets.
    """

    def __init__(self, filepath):
        """Load ``filepath`` and split it into feature and target tensors.

        Args:
            filepath: Anything ``np.loadtxt`` accepts as a source, typically
                a path to a comma-separated file of numeric values.
        """
        super().__init__()
        # ndmin=2 keeps a one-row file as shape (1, n_columns). Without it
        # loadtxt squeezes the result to 1-D, shape[0] would count columns
        # instead of rows, and the 2-D slicing below would raise IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Slice with -1: (not -1) so the targets stay 2-D: (rows, 1).
        self.y_data = torch.from_numpy(xy[:, -1:])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len
# Load the full dataset into memory from the compressed CSV file.
dataset = DiabetesDataset(filepath='./data/diabetes.csv.gz')

# Serve the samples in shuffled mini-batches of 32, using 2 worker processes.
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)