pytorch下dataset和dataloader极简实践(包括自带图片)

数据类

数据集主要是

torch.utils.data类

要实现加载和预处理数据可分为以下两个步骤:

1.加载数据集(Dateset)

1.1 自带数据集(Mnist/FashionMnist等)

加载时需要完成数据格式的转换(transform).

一种加载方法是用自带的数据集,来自torchvision大类:


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=True,
    transform=transform)
testset = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=False,
    transform=transform)

1.2 自备图片

若要实现自有文件图片,需要实现一个继承torch.utils.data.Dataset的类.这里dataset有两种实现方式:

  • map-style(类似数组)
    需要实现两个数组函数__getitem__()和__len__()。
  • Iterable-style(类似指针)
    这里需要实现迭代函数__iter()__。

下例实现map-style()函数,在该函数中可以通过索引把图像数据转换,返回为tensor数据.

import torch.utils.data as data
class DatasetFromFolder(data.Dataset):
    def __init__(self):
        super().__init__()
        self.path = 'data/pose'#指定自己的路径
        self.image_filenames = [x for x in listdir(self.path)]
    def __getitem__(self, index):
        a = Image.open(join(self.path, self.image_filenames[index])).convert('L')
        a = a.resize((64, 64), Image.BICUBIC)
        a = transforms.ToTensor()(a)
        return a
    def __len__(self):
        return len(self.image_filenames)

2.预处理数据

就是加载数据,这里需要定义一个DataLoader类并设置必要参数,如一批数据batch的数量,是否随机,

pose = DatasetFromFolder()
train_loader = torch.utils.data.DataLoader(
     dataset=pose,
     batch_size=25,
     shuffle=False,
     num_workers=0,
     pin_memory=True,#用Nvidia GPU时生效
     drop_last=True
 )

3.测试

通过迭代train_loader类,来每次输出一个batch,如:

 for i, x in enumerate(train_loader):
     print(i)
     print(x.shape)
     #torchvision.utils.save_image(x, './pose-img/%d.jpg'%(i), nrow=5)

4.其他

如果在使用datasets.ImageFolder(path)时,出现 'Found 0 files in subfolders of: xxx’这个错误,还是乖乖用继承上述类实现加载自身数据吧。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,下面是使用PyTorch中的DatasetDataloader的示例程序: 首先,我们需要导入相关的库: ```python import torch from torch.utils.data import Dataset, DataLoader ``` 接下来我们定义一个数据集类,这个类需要继承自`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。这里我们假设我们有一个包含10个样本的数据集,每个样本由一个图片和一个标签组成: ```python class MyDataset(Dataset): def __init__(self): self.data = [(torch.randn(3, 224, 224), i) for i in range(10)] def __len__(self): return len(self.data) def __getitem__(self, index): x, y = self.data[index] return x, y ``` 接下来,我们可以使用`MyDataset`类创建一个数据集对象,并使用`DataLoader`类创建一个数据加载器对象。在创建`DataLoader`对象时,我们需要指定`batch_size`和`shuffle`参数: ```python my_dataset = MyDataset() my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True) ``` 现在我们可以使用`my_dataloader`迭代数据集中的样本了,每个迭代器返回一个包含`batch_size`个样本的元组,其中第一个元素是一个大小为`(batch_size, 3, 224, 224)`的张量,代表`batch_size`个图片,第二个元素是一个大小为`(batch_size,)`的张量,代表`batch_size`个标签。我们可以使用下面的代码来迭代数据集: ```python for x, y in my_dataloader: print(x.shape, y.shape) ``` 输出结果如下: ``` torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) torch.Size([2, 3, 224, 224]) torch.Size([2]) ``` 这个程序演示了如何使用PyTorch中的DatasetDataloader来加载数据集,并迭代数据集中的样本。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值