pytorch中的sampler

import torch
import torch.utils.data
from torch.utils.data import Dataset
from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler


import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *
 
transform = transforms.Compose([
    torchvision.transforms.ToTensor()
])

'''
深度学习数据集准备主要是 3 步
1. 创建数据集类
2. 实例化 数据集
3. 创建dataloader
其中, 创建dataloader中涉及到一些常用策略, 主要是sampler的选择。设置好sampler之后,扔进dataloader的构造函数中。当然sampler可以用默认的。


torch.utils.data.Sampler 采样器的基类,继承自它的类必须提供__iter__(), __len__()方法,用来通过数据集元素的索引来迭代数据集。

torch.utils.data.SequentialSampler 总是以相同的顺序来迭代数据集的所有元素

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None) : 
指定  replacement[重复放置索引号], 可以重复采用,就一次采32个可重复的 randint来取inde,共采 num_samples 个。num_samples // len(dataset)+ 1 这么多次
不指定replacement, 通过randperm 一次取出来所有的随机后的index。

torch.utils.data.SubsetRandomSampler(indices) 从索引中随机 randperm采样,不重复。

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True) 根据概率采样num_samples个样本。


torch.utils.data.BatchSampler(sampler, batch_size, drop_last) 包裹另一个采样器来产生 mini-batch。

torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0) 
torch.utils.data.Subset(
      full_dataset, indices=range(0, len(full_dataset), slice))

case 1. WeightedRandomSampler 适用于类别不平衡来进行按概率平衡采样. iter为iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
        weights为index权重,权重越大的取到的概率越高
        num_samples: 生成的采样长度,也就是最终经过采用后的数据集长度,也就是单次采样的次数
        replacement:是否为有放回取样
        multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}

case 2. SubsetRandomSampler 根据index从数据集中抽取这些index对应的图片,然后随机排序. iter为 (self.indices[i] for i in torch.randperm(len(self.indices)))
所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样。比读进来数据集,再划分数据集快一些。
        torch.randperm对数组随机排序
        indices为给定的下标数组
case 3. 



'''
class MyDataset(Dataset):
    def __init__(self, data_length):
        self.data_length = data_length
        self.data = list(range(self.data_length))
        self.label = [str(item) for item in self.data]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index], self.label[index]
    
trainset = MyDataset(10)
#testset = MyDataset(8)
'''
#测试 WeightedRandomSampler

weights = [2 if data < 5 else 1 for data, label in trainset]
# num_samples 一意思是这个数据集一共的采用长度
sampler = WeightedRandomSampler(weights, num_samples=25, replacement=True)
train_loader = DataLoader(trainset, batch_size=1, sampler=sampler)
'''

# 测试

split_num = int(len(trainset) * 0.5)
index_list = list(range(len(trainset)))
train_idx, test_idx = index_list[:split_num], index_list[split_num:]


train_sampler = sampler.SubsetRandomSampler(train_idx)
test_sampler = sampler.SubsetRandomSampler(test_idx)
 
train_loader = DataLoader(trainset, batch_size=6,
                          sampler=train_sampler)
 
loader_val = DataLoader(test_sampler, batch_size=1,
                        sampler=test_sampler)

small = 0
big = 0
print(f'dataset length = {len(train_loader)}')
for data, label in train_loader:
    print(f'label = {label}, data = {data}')
    if int(label[0]) < 5:
        small+=1
    else:
        big+=1
print(f"small num={small}, big num={big}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值