刘二大人 《PyTorch深度学习实践》——第8讲 加载数据集(代码详解)
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class DiabetesDataset(Dataset): # 定义一个糖尿病人数据集类,继承自Dataset
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) # 加载训练数据集
self.len = xy.shape[0] # shape本身是一个二元组(x,y)对应数据集的行数和列数,这里[0]我们取行数,即样本数
self.x_data = torch.from_numpy(xy[:, :-1]) # 取前八列 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要
self.y_data = torch.</