Pytorch深度学习实践 第八讲 加载数据集

Epoch:所有样本都进行一次forward和backward。

Batch_size:训练样本中一次forward和backward的样本数。

Iteration:内迭代的次数,即训练集中有多少个batch_size,总样本数N/Batch_size。

Dataloader中的batch和shuffle:

示例代码:(还是diabetes.csv.gz数据集)

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class DiabetesDataset(Dataset):
    def __init__(self,filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
    def __getitem__(self, index):
        #可以按索引取数据
        return self.x_data[index],self.y_data[index] #返回的是(x_data,y_data)元组
    def __len__(self):
        #可以获得数据集长度
        return self.len
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8,6)
        self.linear2 = torch.nn.Linear(6,4)
        self.linear3 = torch.nn.Linear(4,1)
        self.sigmoid = torch.nn.Sigmoid()
    def forward(self,x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
model = Model()
filepath = 'diabetes.csv.gz'
dataset = DiabetesDataset(filepath)
train_loader = DataLoader(dataset=dataset,batch_size=22,shuffle=True,num_workers=2)
# print(train_loader)

critarion = torch.nn.BCELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) #将模型参数作为优化算法的参数


if __name__ == '__main__':
    epoch_list = []
    loss_list = []
    for epoch in range(20):
        epoch_list.append(epoch)
        for i, data in enumerate(train_loader,0):
            #准备数据,enumerate()用来列出数据和数据下标
            input, labels = data #是张量
            #计算损失
            y_pred = model(input)
            loss = critarion(y_pred, labels)
            print(epoch,i,loss.item()) #epoch是迭代次数,i即iteration是内迭代次数,就是总样本数/batch_size
            loss_list.append(loss.item())
            #反馈
            optimizer.zero_grad()
            loss.backward()
            #更新
            optimizer.step()
    # plt.plot(epoch_list,loss_list)
    # plt.xlabel('epoch')
    # plt.ylabel('loss')
    # plt.show()

 说明:

1.通过Dataset类构造带有索引的数据集,它是一个抽象类,只能被其他子类继承,不能被实例化,要调用的话就要继承Dataset来产生一个可实例化的自定义的类,即代码中的class DiabetesDataset(Dataset),重写__getitem__()函数实现按索引取数据,返回的是(x_data,y_data)元组,重写__len__()函数获得数据集长度。

2.DataLoader是将数据集分成Mini-Batch的类。帮助加载数据集

3.num_workers:用几个线程来并行计算

4.enumerate()函数:列出数据及数据下标

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值