昇思MindSpore进阶教程-数据采样

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

数据采样

为满足训练需求,解决诸如数据集过大或样本类别分布不均等问题,MindSpore提供了多种不同用途的采样器(Sampler),帮助用户对数据集进行不同形式的采样。用户只需在加载数据集时传入采样器对象,即可实现数据的采样。

MindSpore目前提供了如RandomSampler、WeightedRandomSampler、SubsetRandomSampler等多种采样器。此外,用户也可以根据需要实现自定义的采样器类。

采样器

下面主要以CIFAR-10数据集为例,介绍几种常用MindSpore采样器的使用方法。
在这里插入图片描述

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"

path = download(url, "./", kind="tar.gz", replace=True)

解压后数据集文件的目录结构如下:

.
└── cifar-10-batches-bin
    ├── batches.meta.txt
    ├── data_batch_1.bin
    ├── data_batch_2.bin
    ├── data_batch_3.bin
    ├── data_batch_4.bin
    ├── data_batch_5.bin
    ├── readme.html
    └── test_batch.bin

RandomSampler

从索引序列中随机采样指定数目的数据。

下面的样例使用随机采样器,分别从数据集中有放回和无放回地随机采样5个数据,并打印展示。为了便于观察有放回与无放回的效果,这里自定义了一个数据量较小的数据集。

from mindspore.dataset import RandomSampler, NumpySlicesDataset

np_data = [1, 2, 3, 4, 5, 6, 7, 8]  # 数据集

# 定义有放回采样器,采样5条数据
sampler1 = RandomSampler(replacement=True, num_samples=5)
dataset1 = NumpySlicesDataset(np_data, column_names=["data"], sampler=sampler1)

print("With Replacement:    ", end='')
for data in dataset1.create_tuple_iterator(output_numpy=True):
    print(data[0], end=' ')

# 定义无放回采样器,采样5条数据
sampler2 = RandomSampler(replacement=False, num_samples=5)
dataset2 = NumpySlicesDataset(np_data, column_names=["data"], sampler=sampler2)

print("\nWithout Replacement: ", end='')
for data in dataset2.create_tuple_iterator(output_numpy=True):
    print(data[0], end=' ')

从上面的打印结果可以看出,使用有放回采样器时,同一条数据可能会被多次获取;使用无放回采样器时,同一条数据只能被获取一次。

WeightedRandomSampler

指定长度为N的采样概率列表,按照概率在前N个样本中随机采样指定数目的数据。

下面的样例使用带权随机采样器从CIFAR-10数据集的前10个样本中按概率获取6个样本,并展示已读取数据的形状和标签。

import math
import matplotlib.pyplot as plt
from mindspore.dataset import WeightedRandomSampler, Cifar10Dataset
%matplotlib inline

DATA_DIR = "./cifar-10-batches-bin/"

# 指定前10个样本的采样概率并进行采样
weights = [0.8, 0.5, 0, 0, 0, 0, 0, 0, 0, 0]
sampler = WeightedRandomSampler(weights, num_samples=6)
dataset = Cifar10Dataset(DATA_DIR, sampler=sampler)  # 加载数据

def plt_result(dataset, row):
    """显示采样结果"""
    num = 1
    for data in dataset.create_dict_iterator(output_numpy=True):
        print("Image shape:", data['image'].shape, ", Label:", data['label'])
        plt.subplot(row, math.ceil(dataset.get_dataset_size() / row), num)
        image = data['image']
        plt.imshow(image, interpolation="None")
        num += 1

plt_result(dataset, 2)

在这里插入图片描述
从上面的打印结果可以看出,本次在前面一共10个样本中随机采样了6条数据,只有前面两个采样概率不为0的样本才有机会被采样。

SubsetRandomSampler

从指定样本索引子序列中随机采样指定数目的样本数据。

下面的样例使用子序列随机采样器从CIFAR-10数据集的指定子序列中抽样3个样本,并展示已读取数据的形状和标签。

from mindspore.dataset import SubsetRandomSampler

# 指定样本索引序列
indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
sampler = SubsetRandomSampler(indices, num_samples=6)
# 加载数据
dataset = Cifar10Dataset(DATA_DIR, sampler=sampler)

plt_result(dataset, 2)

PKSampler

在指定的数据集类别P中,每种类别各采样K条数据。

下面的样例使用PK采样器从CIFAR-10数据集中每种类别抽样2个样本,最多10个样本,并展示已读取数据的形状和标签。

from mindspore.dataset import PKSampler

# 每种类别抽样2个样本,最多10个样本
sampler = PKSampler(num_val=2, class_column='label', num_samples=10)
dataset = Cifar10Dataset(DATA_DIR, sampler=sampler)

plt_result(dataset, 3)

DistributedSampler

在分布式训练中,对数据集分片进行采样。

下面的样例使用分布式采样器将构建的数据集分为4片,在分片抽取一个样本,共采样3个样本,并展示已读取的数据。

from mindspore.dataset import DistributedSampler

# 自定义数据集
data_source = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# 构建的数据集分为4片,共采样3个数据样本
sampler = DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=3)
dataset = NumpySlicesDataset(data_source, column_names=["data"], sampler=sampler)

# 打印数据集
for data in dataset.create_dict_iterator():
    print(data)

自定义采样器

用户可以自定义采样器,并把它应用到数据集上。

iter 模式

用户可以继承Sampler基类,通过实现__iter__方法来自定义采样器的采样方式。

下面的样例定义了一个从下标0至下标9间隔为2采样的采样器,将其作用于自定义数据集,并展示已读取数据。

import mindspore.dataset as ds

# 自定义采样器
class MySampler(ds.Sampler):
    def __iter__(self):
        for i in range(0, 10, 2):
            yield i

# 自定义数据集
np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

# 加载数据
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
for data in dataset.create_tuple_iterator(output_numpy=True):
    print(data[0], end=' ')

getitem 模式

用户可以定义一个采样器类,该类包含 initgetitemlen 方法。

下面的样例定义了一个下标为 [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8] 的采样器类,将其作用于自定义数据集,并展示已读取数据。

import mindspore.dataset as ds

# 自定义采样器
class MySampler():
    def __init__(self):
        self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
    def __getitem__(self, index):
        return self.index_ids[index]
    def __len__(self):
        return len(self.index_ids)

# 自定义数据集
np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

# 加载数据
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
for data in dataset.create_tuple_iterator(output_numpy=True):
    print(data[0], end=' ')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

明志刘明

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值