PyTorch 的 DataLoader 在 `shuffle=True` 时默认使用 `RandomSampler`,即不放回的随机采样(`shuffle=False` 时则按顺序采样)。如何实现有放回采样,或者只采样某些类别的样本呢?这时候就需要自定义采样方式。实现也很简单:参照源码写一个采样器,改几行代码就可以。
采样源码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
class RandomSampler(Sampler):
    r"""Samples elements randomly.

    If without replacement, then sample from a shuffled dataset. If with
    replacement, then user can specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from; only its ``len()`` is used.
        replacement (bool): samples are drawn with replacement if ``True``,
            default=False.
        num_samples (int): number of samples to draw, default=len(dataset)
            (re-read on every access, so it follows the dataset's current
            length). Without replacement, a value larger than ``len(dataset)``
            is served by concatenating independent permutations, and a smaller
            value takes a prefix of one permutation.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        if num_samples is not None and num_samples < 0:
            raise ValueError(
                "num_samples should be a non-negative integer value, "
                "but got num_samples={}".format(num_samples)
            )
        # None means "use len(data_source)"; resolved lazily in the property.
        self._num_samples = num_samples

    @property
    def num_samples(self):
        """Number of indices one pass of ``__iter__`` yields."""
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    @num_samples.setter
    def num_samples(self, value):
        # Kept assignable, as it was a plain attribute before.
        self._num_samples = value

    def __iter__(self):
        n = len(self.data_source)
        count = self.num_samples
        if self.replacement:
            # Every index is drawn independently and uniformly from [0, n).
            return iter(torch.randint(high=n, size=(count,), dtype=torch.int64).tolist())
        if n == 0:
            # Nothing to permute (and divmod below would divide by zero).
            return iter([])
        # Without replacement: whole permutations first, then a truncated one for
        # the remainder, so exactly `count` indices come out. With the default
        # count == n this is a single torch.randperm(n), as before.
        full_passes, remainder = divmod(count, n)
        indices = []
        for _ in range(full_passes):
            indices.extend(torch.randperm(n).tolist())
        if remainder:
            indices.extend(torch.randperm(n)[:remainder].tolist())
        return iter(indices)

    def __len__(self):
        # Must agree with what __iter__ yields: DataLoader / BatchSampler derive
        # the number of batches from it.
        return self.num_samples
`data_source` 表示要读取的数据([输入, 标签], …),这里只用到它的长度信息。最主要的是 `__iter__` 函数,其中 `torch.randperm(n).tolist()` 返回 [0, 1, …, n-1] 的一个随机排列。
自定义实现
自定义采样,就是先得到我们想要采样的样本 index,再把它们作为 `__iter__` 函数的输出即可。
- 放回采样:在 [0, n-1] 范围内有放回地随机采样
- 过采样:在原有 index 的基础上,额外增加某个类别样本的 index
- 采样类别为 k 的样本:获得标签为 k 的样本 index(这个可以直接用 `SubsetRandomSampler` 实现,不用自己去修改源码)
class Mydataset(Dataset):
    """Toy dataset: a shuffled range of integers with random class labels.

    ``data`` is a random permutation of ``0 .. size-1``; ``label`` holds one
    label per item, drawn uniformly from ``[0, num_classes)``.

    Args:
        size (int): number of items, default 8.
        num_classes (int): number of distinct labels, default 2.
    """

    def __init__(self, size=8, num_classes=2):
        # Same draw order as before (data first, then labels), so seeded runs
        # with the defaults reproduce the original values.
        self.data = torch.randperm(size)
        self.label = torch.randint(0, num_classes, (size,))

    def __getitem__(self, index):
        # The index itself is returned too, so printed batches show exactly
        # which samples the sampler picked.
        return index, self.label[index], self.data[index]

    def __len__(self):
        return len(self.data)
def get_fixed_label_samples(dataset, label=0):
    """Return, as a list of ints, the positions whose label equals ``label``.

    Reads ``dataset.label`` (assumed to be a tensor, as in ``Mydataset``) and
    yields indices suitable for ``SubsetRandomSampler``.
    """
    matches = dataset.label.eq(label)
    # nonzero(as_tuple=True) is the documented equivalent of one-argument
    # torch.where; element 0 holds the positions along the first dimension.
    positions = matches.nonzero(as_tuple=True)[0]
    return positions.tolist()
# Demo: iterate only over the samples whose label is 0, three per batch.
# The __main__ guard is required because num_workers > 0 starts worker
# processes; under the "spawn" start method (default on Windows and macOS) the
# workers re-import this module, and unguarded top-level code would try to
# start new processes during bootstrapping and fail with a RuntimeError.
if __name__ == "__main__":
    dataset = Mydataset()
    ind = get_fixed_label_samples(dataset)
    dl = DataLoader(
        dataset=dataset,
        # To sample from every class instead, pass
        # RandomSampler(data_source=dataset) as the sampler.
        sampler=SubsetRandomSampler(ind),
        batch_size=3,
        num_workers=2,
    )
    for item in dl:
        print(item)