自定义采样方式 Pytorch

Pytorch的默认采样方式为:不放回的随机采样。如何实现有放回采样或者只采样某些类别的样本呢?这时候,我们需要自定义采样方式了。实现也很简单,就是修改源码几行代码就ok。

采样源码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler, SubsetRandomSampler


class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``, default=False
    """
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self.num_samples = num_samples
        if self.num_samples is None:
            self.num_samples = len(self.data_source)


    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return len(self.data_source)
   

data_source表示读取的数据([输入,标签],…),这里只使用它的长度信息;最主要就是_iter_函数,torch.randperm(n).tolist()返回[0,1,…,n-1]的随机排列。

自定义实现

自定义采样,就是获得我们想要样本的index,然后作为_iter_函数的输出就行。

  • 放回采样:[0,n-1]范围内的随机重采样
  • 过采样:在原有基础上,增加某个类别样本的index
  • 采样类别为k的样本:获得标签为k的样本index(这个可以由SubsetRandomSampler,不同自己去修改源码)
class Mydataset(Dataset):
    def __init__(self):
        self.data = torch.randperm(8)
        self.label = torch.randint(0,2,(8,))
        
    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]
        return index, label, data

    def __len__(self):
        return len(self.data)


def get_fixed_label_samples(dataset, label=0):
    labels = dataset.label
    return torch.where(labels == label)[0].tolist()

dataset = Mydataset()
ind = get_fixed_label_samples(dataset)
dl = DataLoader(
        dataset=dataset,
        # sampler=RandomSampler(data_source=dataset),
        sampler=SubsetRandomSampler(ind),
        batch_size=3,
        num_workers=2
)

for item in dl:
    print(item)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值