import torch
import torch.utils.data
from torch.utils.data import Dataset
from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler
import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *
transform = transforms.Compose([
torchvision.transforms.ToTensor()
])
'''
深度学习数据集准备主要是 3 步
1. 创建数据集类
2. 实例化 数据集
3. 创建dataloader
其中, 创建dataloader中涉及到一些常用策略, 主要是sampler的选择。设置好sampler之后,扔进dataloader的构造函数中。当然sampler可以用默认的。
torch.utils.data.Sampler 采样器的基类,继承自它的类必须提供__iter__(), __len__()方法,用来通过数据集元素的索引来迭代数据集。
torch.utils.data.SequentialSampler 总是以相同的顺序来迭代数据集的所有元素
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None) :
指定 replacement[重复放置索引号], 可以重复采用,就一次采32个可重复的 randint来取inde,共采 num_samples 个。num_samples // len(dataset)+ 1 这么多次
不指定replacement, 通过randperm 一次取出来所有的随机后的index。
torch.utils.data.SubsetRandomSampler(indices) 从索引中随机 randperm采样,不重复。
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True) 根据概率采样num_samples个样本。
torch.utils.data.BatchSampler(sampler, batch_size, drop_last) 包裹另一个采样器来产生 mini-batch。
torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0)
torch.utils.data.Subset(
full_dataset, indices=range(0, len(full_dataset), slice))
case 1. WeightedRandomSampler 适用于类别不平衡来进行按概率平衡采样. iter为iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
weights为index权重,权重越大的取到的概率越高
num_samples: 生成的采样长度,也就是最终经过采用后的数据集长度,也就是单次采样的次数
replacement:是否为有放回取样
multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}
case 2. SubsetRandomSampler 根据index从数据集中抽取这些index对应的图片,然后随机排序. iter为 (self.indices[i] for i in torch.randperm(len(self.indices)))
所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样。比读进来数据集,再划分数据集快一些。
torch.randperm对数组随机排序
indices为给定的下标数组
case 3.
'''
class MyDataset(Dataset):
def __init__(self, data_length):
self.data_length = data_length
self.data = list(range(self.data_length))
self.label = [str(item) for item in self.data]
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index], self.label[index]
trainset = MyDataset(10)
#testset = MyDataset(8)
'''
#测试 WeightedRandomSampler
weights = [2 if data < 5 else 1 for data, label in trainset]
# num_samples 一意思是这个数据集一共的采用长度
sampler = WeightedRandomSampler(weights, num_samples=25, replacement=True)
train_loader = DataLoader(trainset, batch_size=1, sampler=sampler)
'''
# 测试
split_num = int(len(trainset) * 0.5)
index_list = list(range(len(trainset)))
train_idx, test_idx = index_list[:split_num], index_list[split_num:]
train_sampler = sampler.SubsetRandomSampler(train_idx)
test_sampler = sampler.SubsetRandomSampler(test_idx)
train_loader = DataLoader(trainset, batch_size=6,
sampler=train_sampler)
loader_val = DataLoader(test_sampler, batch_size=1,
sampler=test_sampler)
small = 0
big = 0
print(f'dataset length = {len(train_loader)}')
for data, label in train_loader:
print(f'label = {label}, data = {data}')
if int(label[0]) < 5:
small+=1
else:
big+=1
print(f"small num={small}, big num={big}")
pytorch中的sampler
最新推荐文章于 2024-04-18 20:11:14 发布