B站 刘二大人 传送门 加载数据集
本节用的还是糖尿病数据集,老师放的百度云课件中有数据集压缩包,自行下载。
这一节把数据加载做成了类,并且增加了批量处理。代码中有一些测试和绘图后注释掉的内容,这里就不删除了。
链接:https://pan.baidu.com/s/1vZ27gKp8Pl-qICn_p2PaSw
提取码:cxe4
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
class DiabetesDataset(Dataset):
def __init__(self,filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) # 'delimiter'为分隔符
print('xy = \n', xy.shape)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self