一、代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import random
path =
transforms=
proportion=0.1 #测试集比例
batch_size=32
data = datasets.ImageFolder(path,transforms)
n = len(data) #数据集总数
n_test = random.sample(range(1, n), int(proportion * n)) #按比例取随机数列表
test_set = torch.utils.data.Subset(data, n_test) #按照随机数列表取测试集
train_set = torch.utils.data.Subset(data,list(set(range(1, n)).difference(set(n_test)))) #测试集剩下作为训练集
data_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
data_test=DataLoader(test_set, batch_size=batch_size, shuffle=False)
#输出筛选的训练集labels
for batch_idex, (data, targets) in enumerate(data_test):
print(batch_idex,targets)
二、测试结果
用了十类的图片数据集测试,结果数据集成功被打乱了!
三、后记
网上其它的代码只进行划分忽略了打乱这个环节,那可能有人会问DataLoader里不是有shuffle吗,为什么不用呢?
- 因为是先划分的数据集,如果数据集的标签是连续排列的,划分的数据集的标签会出现扎堆现象,后续再在DataLoader时打乱就没效果啦。就像下面这样,测试集将0,1标签都取走了而没有其它标签,这显然不是一个合理的数据集!
list取补集代码:list(set(range(1, n)).difference(set(n_test)))
- 取完补集最后需要转成list,不然会报错:TypeError: ‘set’ object is not subscriptable