Pytorch中数据采样方法Sampler(torch.utils.data)(二) —— WeightedRandomSampler & SubsetRandomSampler

WeightedRandomSampler加权随机采样

平衡不平衡数据的抽取

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

其中__iter__为:

iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

其中

  • weights为index权重,权重越大的取到的概率越高
  • num_samples: 生成的采样长度
  • replacement:是否为有放回取样
  • multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}

如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍

import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)

## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
weights = [2 if label == 1 else 1 for data, label in trainset]
sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
dataloader = DataLoader(trainset, batch_size=16, sampler=sampler)

SubsetRandomSampler索引随机采样

根据index从数据集中抽取这些index对应的图片,然后随机排序

torch.utils.data.SubsetRandomSampler(indices)

其中__iter__为:

(self.indices[i] for i in torch.randperm(len(self.indices)))

其中

  • torch.randperm对数组随机排序
  • indices为给定的下标数组

所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
 

如果我要划分train_set和test_set, 那么读进整个数据集来再split比较慢

不如我直接生成train_set的index和test_set的index这样就可以很快了,所以就出现了SubsetRandomSampler

import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
 
testset = torchvision.datasets.MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)

split_num = int(len(trainset) * 0.8)
index_list = list(range(len(trainset)))
train_idx, val_idx = index_list[:split_num], index_list[split_num:]

train_sampler = sampler.SubsetRandomSampler(train_idx)
val_sampler = sampler.SubsetRandomSampler(val_idx)

loader_train = DataLoader(trainset, batch_size=100,
                          sampler=train_sampler)

loader_val = DataLoader(trainset, batch_size=100,
                        sampler=val_sampler)

  • 5
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
SubsetRandomSampler是PyTorch的一个采样器类,用于在给定数据集的子集上进行随机采样。它可以用于创建mini-batch训练数据集和验证数据集。 在使用SubsetRandomSampler时,你需要提供一个索引列表,该列表表示数据的样本索引。然后SubsetRandomSampler将根据这些索引随机选择样本。 下面是一个使用SubsetRandomSampler的示例: ```python import torch from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler # 创建自定义数据集 class CustomDataset(Dataset): def __init__(self): self.data = [1, 2, 3, 4, 5] def __getitem__(self, index): return self.data[index] def __len__(self): return len(self.data) # 创建数据集实例 dataset = CustomDataset() # 设定训练集索引 train_indices = [0, 1, 2] # 设定验证集索引 val_indices = [3, 4] # 创建采样器实例 train_sampler = SubsetRandomSampler(train_indices) val_sampler = SubsetRandomSampler(val_indices) # 创建数据加载器 train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler) val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler) # 在训练集上进行迭代 for batch in train_loader: print(batch) # 在验证集上进行迭代 for batch in val_loader: print(batch) ``` 在上面的示例,我们创建了一个自定义的数据集CustomDataset,并使用SubsetRandomSampler将数据集划分为训练集和验证集。然后,我们可以使用DataLoader加载数据集,并通过迭代器访问数据集的mini-batch。 希望这个例子能帮助你理解SubsetRandomSampler的用法。如果还有其他问题,请继续提问。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值