使用datasets.ImageFolder()划分数据集并打乱顺序(简单易懂)

这篇博客介绍了如何使用Python的torch库来划分和打乱数据集,以避免标签扎堆现象。首先,定义了数据路径、转换操作、测试集比例和批处理大小。接着,通过random.sample选择测试集样本,然后利用Subset创建训练集和测试集。使用DataLoader加载数据时,训练集启用shuffle,而测试集则不启用。博主指出,直接在DataLoader中打乱已经划分的数据集可能无效,因为数据集标签的连续性可能导致划分后的数据集中标签分布不均。通过这段代码,可以确保训练集和测试集的合理分布。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random

path = 
transforms=
proportion=0.1 #测试集比例
batch_size=32

data = datasets.ImageFolder(path,transforms)
n = len(data)  #数据集总数
n_test = random.sample(range(1, n), int(proportion * n))  #按比例取随机数列表

test_set = torch.utils.data.Subset(data, n_test)  #按照随机数列表取测试集
train_set = torch.utils.data.Subset(data,list(set(range(1, n)).difference(set(n_test))))  #测试集剩下作为训练集

data_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
data_test=DataLoader(test_set, batch_size=batch_size, shuffle=False)

#输出筛选的训练集labels
for batch_idex, (data, targets) in enumerate(data_test):
    print(batch_idex,targets)

二、测试结果

用了十类的图片数据集测试,结果数据集成功被打乱了!

在这里插入图片描述

三、后记

网上其它的代码只进行划分忽略了打乱这个环节,那可能有人会问DataLoader里不是有shuffle吗,为什么不用呢?

  • 因为是先划分的数据集,如果数据集的标签是连续排列的,划分的数据集的标签会出现扎堆现象,后续再在DataLoader时打乱就没效果啦。就像下面这样,测试集将0,1标签都取走了而没有其它标签,这显然不是一个合理的数据集!

在这里插入图片描述

list取补集代码:list(set(range(1, n)).difference(set(n_test)))

  • 取完补集最后需要转成list,不然会报错:TypeError: ‘set’ object is not subscriptable
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值