WeightedRandomSampler 重采样方法
import mindspore.dataset as ds
weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3] #weights貌似是要小于等于数据集的长度
dataset_path = "../../../mindspore_learning/MNIST_Data/train"
sampler = ds.WeightedRandomSampler(weights, 5) #5表示采样的数量,就是后面dataset的长度
# 定义数据集和采样器
dataset = ds.MnistDataset(dataset_path, sampler=sampler)
print(len(dataset))
有时候数据集中的不同样本对模型的训练贡献是不同的,因此可以使用加权采样来调整样本的抽样概率。WeightedRandomSampler
就是这样一种采样器,它根据提供的权重列表来决定每个样本被抽样的概率。
比如数据集有1万条数据,然后这个weights的长度也是1万,sampler = ds.WeightedRandomSampler(weights, 10000),dataset = ds.MnistDataset(dataset_path, sampler=sampler),最终dataset得到了这10000条数据, 但是和数据集的数据是不一样的
在mindone代码里的使用方法
ds=de.GeneratorDataset(dataset,column_names=dataset_column_names,sampler=de.WeightedRandomSampler([0.9,0.01,0.4,0.8,0.9,0.01,0.4,0.8],8),num_parallel_workers=min(32, num_parallel_workers),python_multiprocessing=python_multiprocessing)
#下面注释掉的是之前的代码
# ds = de.GeneratorDataset(
# dataset,
# column_names=dataset_column_names,
# num_parallel_workers=min(32, num_parallel_workers),
# shuffle=shuffle,
# python_multiprocessing=python_multiprocessing,
# )
当权重列表的长度与数据集大小不一致时,WeightedRandomSampler
会按照权重列表的长度进行取余操作,以确保在每个 epoch 内能够完整覆盖数据集。具体来说,如果权重列表的长度小于数据集大小,那么这个权重列表将在整个数据集上循环使用。
例如,如果您的数据集大小为1000,而权重列表的长度为8,那么 WeightedRandomSampler
在每个 epoch 中将按照如下顺序使用这8个权重来对数据集进行采样。