使用datasets.ImageFolder()划分数据集并打乱顺序(简单易懂)

一、代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random

path = 
transforms=
proportion=0.1 #测试集比例
batch_size=32

data = datasets.ImageFolder(path,transforms)
n = len(data)  #数据集总数
n_test = random.sample(range(1, n), int(proportion * n))  #按比例取随机数列表

test_set = torch.utils.data.Subset(data, n_test)  #按照随机数列表取测试集
train_set = torch.utils.data.Subset(data,list(set(range(1, n)).difference(set(n_test))))  #测试集剩下作为训练集

data_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
data_test=DataLoader(test_set, batch_size=batch_size, shuffle=False)

#输出筛选的训练集labels
for batch_idex, (data, targets) in enumerate(data_test):
    print(batch_idex,targets)

二、测试结果

用了十类的图片数据集测试,结果数据集成功被打乱了!

在这里插入图片描述

三、后记

网上其它的代码只进行划分忽略了打乱这个环节,那可能有人会问DataLoader里不是有shuffle吗,为什么不用呢?

  • 因为是先划分的数据集,如果数据集的标签是连续排列的,划分的数据集的标签会出现扎堆现象,后续再在DataLoader时打乱就没效果啦。就像下面这样,测试集将0,1标签都取走了而没有其它标签,这显然不是一个合理的数据集!

在这里插入图片描述

list取补集代码:list(set(range(1, n)).difference(set(n_test)))

  • 取完补集最后需要转成list,不然会报错:TypeError: ‘set’ object is not subscriptable
  • 8
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
`datasets.ImageFolder`是PyTorch中用于加载图像数据集的一个类。它可以根据文件夹和文件夹中的图像文件来创建数据集,并且可以自动地将图像数据进行预处理和标准化。 使用`datasets.ImageFolder`可以方便地加载和处理图像数据集。下面是一个示例代码,展示了如何使用`datasets.ImageFolder`加载一个数据集: ```python import torch from torchvision import datasets, transforms # 数据预处理和标准化 data_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 加载数据集 train_dataset = datasets.ImageFolder(root='path/to/train/data', transform=data_transform) val_dataset = datasets.ImageFolder(root='path/to/val/data', transform=data_transform) test_dataset = datasets.ImageFolder(root='path/to/test/data', transform=data_transform) # 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) ``` 在上面的示例代码中,我们首先定义了一个`data_transform`,用于对图像数据进行预处理和标准化。然后,我们使用`datasets.ImageFolder`类分别加载了训练集、验证集和测试集,并将`data_transform`应用到每个数据集中的所有图像上。最后,我们使用PyTorch的`DataLoader`类创建了数据加载器,用于在训练、验证和测试模型时加载数据集
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值