Pytorch Dataset入门

Pytorch Dataset code:torch/utils/data/dataset.py#L17

Pytorch Dataset tutorial: tutorials/beginner/basics/data_tutorial.html


理论:

PyTorch中的Dataset是一个抽象类,用来表示数据集的接口,所有其他数据集都需要继承这个类,并且覆写以下三个方法:

  1. __init__:初始化数据集的一些配置,例如加载所有的数据标签。
  2. __len__:以便len(dataset)可以返回数据集的大小,例如n。如果n小于数据集长度,则只会取前n个的数据。
  3. __getitem__:输入是数据的索引,以便可以使用dataset[i]来获取第i个样本,数据增强一般会在这里做。

代码:

下面是一个自定义的Dataset样例(不可执行):

import cv2
import json
import torch.utils.Dataset as Dataset

class CustomDataset(Dataset):
    def __init__(self, imgs_path, labels_path, img_transform=None, label_transform=None):
        self.imgs_path = imgs_path  # 输入图像的路径,list
        self.labels_path = labels_path  # 输入图像对应的标签路径,list
        self.img_transform = img_transform  # 图像的数据增强
        self.label_transform = label_transform  # 标签的数据增强

    def __len__(self):
        return len(self.imgs_path)  # 返回数据集的长度

    def __getitem__(self, idx):
        img_path = self.imgs_path[idx]
        label_path = self.labels_path[idx]
        img = cv2.imread(img_path)  # 读取图像
        label = json.load(open(label_path))  # 读取标签
        if self.img_transform:  # 图像的数据增强
            img = self.img_transform(img)
        if self.label_transform:  # 标签的数据增强
            label = self.label_transform(label)
        return img, label  # 返回图像和标签,用于训练

总结:

值得注意的是,Dataset只负责数据的加载和预处理,对于如何训练数据(例如:是否进行shuffle,是否进行并行加速等)这部分的逻辑是由DataLoader实现的。通常情况下,我们会将Dataset和DataLoader一起使用。

另外,PyTorch还提供了一些常用的数据集,如:ImageFolder,CIFAR10,MNIST等,这些数据集都是继承Dataset类,同时在init方法中进行数据的下载,以及在getitem方法中进行数据的加载和预处理。

Dataset是单线程读取数据,每次只能读取一个样本,不能一次性读取一个mini-batch的数据。

Dataset的主要特性包含: 

  • 抽象接口:PyTorch通过定义一个抽象Dataset类,让用户可以使用统一的方式来加载各种不同的数据,提供了很好的扩展性。

  • 懒加载:实际的数据载入并不发生在构造数据集实例时,而是发生在用到这些数据时,这样可以提高内存利用率,并且可以实现对大规模数据的处理。

  • 预处理:Dataset的一个重要应用就是数据预处理,你可以在getitem函数中进行任何你的数据预处理过程。


嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
好的,下面是一个简单的PyTorch入门代码,实现了一个三层全连接神经网络,用于MNIST手写数字分类任务。 ``` import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms # 定义超参数 batch_size = 64 learning_rate = 0.01 num_epochs = 10 # 数据预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # 加载数据集 train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 定义模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 784) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x model = Net() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 训练模型 for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): # 前向传播 outputs = model(images) loss = criterion(outputs, labels) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() if (i+1) % 100 == 0: print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(train_loader), loss.item())) # 测试模型 model.eval() with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total)) ``` 以上代码实现了一个三层全连接神经网络,使用MNIST数据集进行训练和测试。首先定义了超参数,然后进行数据预处理和加载数据集。接着定义了一个三层全连接神经网络模型,使用交叉熵损失函数和随机梯度下降优化器进行训练。在训练过程,对模型进行了前向传播和反向传播,最后测试了模型的准确率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值