08Dataset and DataLoader

本文介绍了如何在PyTorch中为结构化数据创建自定义Dataset和DataLoader。文中详细阐述了Dataset的初始化、获取项和长度的方法,并通过实例展示了数据加载器的使用,包括批量大小、样本顺序打乱和线程使用。此外,还提到了图像数据处理的区别,以及一个小实验来验证样本顺序的变化。
摘要由CSDN通过智能技术生成

本小结主要针对结构化数据给出了一种使用pytorch写出自己的Dataset和DataLoader,而针对图像数据,由于数据集可能很大,往往和结构化数据的写法有所区别。其主要区别在于图像数据在构建Dataset的时候,往往不是一次性读取所有图像到内存中,而是生成图像的路径。然后DataLoader每一次加载的是一个batch的图像。

回顾上一小节加载数据的代码

  • 这里一次训练的batch就是全部数据。而pytorch当中,我们一直强调读取的是mini-batch,这是因为对于大数据集而言,只有mini-batch才具备可行性。
    Alt

1,结构化数据的Dataset

  • 自己的Dataset,首先要继承pytorch的Dataset,这是一个抽象类,无法实例化。
  • 然后重写Dataset里面的三个魔法方法,分别是 _ _ i n i t _ _ , _ _ g e t i t e m _ _ , _ _ l e n _ _ . \_\_init\_\_,\_\_getitem\_\_,\_\_len\_\_. __init____getitem__,__len__.
    Alt

处理数据时有两种方法:读取所有数据,数据从__init__加载进来,都读到内存里面,然后每一次调用__getitem__方法的时通过index[i]索引,适合小数据集。
而对于大数据集,几十G,通常把文件名放在列表中,再调用__getitem__方法去文件中读取数据,这样能保证内存的高效使用。

2,创建数据加载器

直接使用DataLoader实例化一个数据迭代器。

  • DataLoader的几个参数解释:
  • 1,dataset就等于自己的datase
  • 2,batch_size一次加载多少个样本
  • 3,shuffle,打乱样本的顺序,即下一个epoch的同一个顺序的batch里面包含的样本是不一样的。如下图所示:
    Alt
    小实验
import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 加载数据集,这里采用的是将全部读取到内存里,用xy变量保存
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)

        self.len = xy.shape[0]  # xy是N行9列,shape是元组(N,9),shape[0]==N

        self.x_data = torch.from_numpy(xy[:, :-1])  # 特征数据矩阵

        self.y_data = torch.from_numpy(xy[:, [-1]])  # target矩阵

    def __getitem__(self, index):
        # getitem实例化对象支持下标操作
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 返回数据集的长度
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,
                          batch_size=4, shuffle=True, num_workers=0)
# 记录两次epoch每一个batch
record1 = []
record2 = []
for epoch in range(2):
    for data in train_loader:
        if epoch==0:
            record1.append(data)
        else:
            record2.append(data)
# 比较第一个batch的标签值即可以知道样本顺序是否被打乱
record1[0][1]==record2[0][1]

实验结果

  • 显然,样本顺序已经被打乱。
    Alt

  • 4,num_works,使用多少个线程跑代码,在Windows上只能使用主线程,即只能设置成0.

完整代码

import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader

# 1,数据准备
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 加载数据集,这里采用的是将全部读取到内存里,用xy变量保存
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)

        self.len = xy.shape[0]  # xy是N行9列,shape是元组(N,9),shape[0]==N

        self.x_data = torch.from_numpy(xy[:, :-1])  # 特征数据矩阵

        self.y_data = torch.from_numpy(xy[:, [-1]])  # target矩阵

    def __getitem__(self, index):
        # getitem实例化对象支持下标操作
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 返回数据集的长度
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,
                          batch_size=256, shuffle=True, num_workers=0)
# 2,设计模型
class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()

        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):

        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

# 3,构建损失函数与优化器
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4,训练周期
for epoch in range(5):

    # 循环对train_loader做迭代,enumerate返回迭代的数据和下标,获取迭代次数
    # 把从train_loader拿出来的(x,y)元组放到data里面
    for i, data in enumerate(train_loader):
        # 1. Prepare data
        inputs, labels = data

        # 2. Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())

        # 3. Backward
        optimizer.zero_grad()
        loss.backward()

        # 4. Update
        optimizer.step()

实验结果

Alt

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值