在将数据输入网络前,一般需要通过DataSet和DataLoader两个类将数据进行整理,但是默认的DataSet类对于无标签数据的支持并不友好,通过查资料发现,可以自己重写DataSet来解决这个问题
class CustomDataSet(Dataset):
def __init__(self, path): #这一步主要是来读取数据,也可在这一步将x,y划分
self.data = pd.read_csv(path, index_col=0)
self.data = np.array(self.data)
def __getitem__(self, index): #这一步主要是返回x,y,若数据中无y则删去和y相关的代码
x_data = self.data[index][:-1]
y_data = self.data[index][-1]
x_data = torch.tensor(x_data,dtype=torch.float32).unsqueeze(0)
y_data = torch.tensor(y_data)
return x_data,y_data
def __len__(self): #返回样本个数,不是特征个数,这个返回结果直接影响index的值
return self.data.shape[0]
通过上述的改写,我们在调用CustomDataSet时只需填写path就可以了,然后将结果放入DataLoader中即可