PyTorch深度学习实践——8 加载数据集(课堂记录)

UP:B站-刘二大人

原视频链接:08.加载数据集_哔哩哔哩_bilibili

'''
随机梯度下降可以有效解决鞍点问题,但是每次训练一个样本,没办法利用GPU或者CPU的并行计算能力,所以计算速度慢,计算时间长
整个Batch一起进行梯度下降计算,虽然计算速度很快,但是没办法有效解决鞍点问题,性能会差一些
所以在这里提出mini-batch来均衡性能和时间的需求
'''

'''
使用mini-batch训练时候是嵌套循环
for epoch in range(training_epochs): # 外层表示训练的周期
    for i in range(total_batch): # 内层对batch进行迭代
'''

# epoch:所有样本进行一次训练是一个epoch;
# batch—size:每次训练时候用的样本数量;
# Iteration: 样本总数/batch-size

'''
1,获取一个batch数据的步骤
(假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m)
1,首先我们要确定数据集的长度n。 
结果类似:n = 1000。
2,然后我们从0到n-1的范围中抽样出m个数(batch大小)。
假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]
3,接着我们从数据集中去取这m个数对应下标的元素。
拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我们将结果整理成两个张量作为输出。
拿到的结果是两个张量,类似batch = (features,labels),
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。
第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader的 sampler和batch_sampler参数指定的。
sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。
batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认(False)方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。
第3个步骤的核心逻辑根据下标取数据集中的元素,是由 Dataset的 __getitem__方法实现的。
第4个步骤的逻辑由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。
'''

import torch
from torch.utils.data import Dataset # Dataset是一个抽象类,不可以直接实例化,需要自己定义子类去继承Dataset
from torch.utils.data import DataLoader # DataLoader是一个可以进行实例化的类
import numpy as np
import matplotlib.pyplot as plt


class DiabetesDataset(Dataset): # DiabetesDataset作为子类去继承Dataset
    '''
    __init__()是初始化函数,之后我们可以提供数据集路径进行数据的加载
    __getitem__()帮助我们通过索引找到某个样本
    __len__()帮助我们返回数据集大小
    '''
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=",", dtype=np.float32)
        self.len = xy.shape[0] # 糖尿病这个数据集是N行,8列指标加1列评价共9列,xy.shape就是[N, 9],所以这步操作实际是得到样本的数量N
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index): # 魔法函数(具体定义可以百度)可以通过下标索引取出数据
        return self.x_data[index], self.y_data[index] # python中return x, y 返回的是(x, y)这样一个元组

    def __len__(self): # 使用len()函数可以把数据的长度返回
        return self.len

dataset = DiabetesDataset("diabetes.csv")
# 四个参数:第一个是数据集, 第二个是每个mini-batch的容量, 第三个是是否需要打乱,第四个num_workers在读入数据时是不是要多线程读取,设置成2就是双线并行
train_loader = DataLoader(dataset = dataset, batch_size=32, shuffle=True) # DataLoader每次会读出一个batch_size

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.active = torch.nn.ReLU()

    def forward(self, x):
        x = self.active(self.linear1(x))
        x = self.active(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

criterion = torch.nn.BCELoss(reduction = "mean")

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


loss_list = []
epoch_list = []

# 使用多线程时候可能会遇到报错,解决报错的方法就是把训练过程封装到函数里面,即:加下面这行代码
if __name__ == '__main__':
    for epoch in range(100):
        loss_minibatch_sum = 0
        for i, (inputs, labels) in enumerate(train_loader, 0):
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print("epoch:", epoch, "loss:", loss.data.item())

            loss_minibatch_sum += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_list.append(loss_minibatch_sum/train_loader.batch_size)
        loss_minibatch_sum = 0
        epoch_list.append(epoch)  # 绘图用


plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值