《昇思25天学习打卡营第04天|数据集Dataset》

数据集

  • 环境准备
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt

数据集加载

  • Mnist数据集为例
  • mindspore.dateset提供的接口仅支持解压后的数据文件,使用download库下载数据集并解压。
# 从开源库下载数据集
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)

数据集迭代

def visualize(dataset):
    figure = plt.figure(figsize=(4, 4)) #定义图片大小
    cols, rows = 3, 3 #定义迭代生成图排列顺序

    plt.subplots_adjust(wspace=0.5, hspace=0.5) # 调整子图之间的宽度和高度

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()): # 迭代过程
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(int(label))
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break
    plt.show()

visualize(train_dataset)

数据集常用操作

  • shuffle
    • 消除数据排列造成的分布不均问题
    train_dataset = train_dataset.shuffle(buffer_size=64)
    
  • map
    • 针对数据集指定列添加数据变换,将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集
      image, label = next(train_dataset.create_tuple_iterator())
      #(28, 28, 1) UInt8
      train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
      image, label = next(train_dataset.create_tuple_iterator())
      #Output (28, 28, 1) Float32
      
  • batch - 将数据集打包成batch是在有限硬件资源下使用梯度下降进行模型优化的折中办法,可以保证梯度下降的随机性和优化计算量
    train_dataset = train_dataset.batch(batch_size=32)
    image, label = next(train_dataset.create_tuple_iterator())
    # Output (32, 28, 28, 1) Float32
    

自定义数据集

可随机访问数据集

  • 通过索引/键直接访问对应位置的数据样本
    # 可随机访问数据集
    class RandomAccessDataset:
        def __init__(self):
            self._data = np.ones((5, 2))
            self._label = np.zeros((5, 1))
    
        def __getitem__(self, index):
            return self._data[index], self._label[index]
    
        def __len__(self):
            return len(self._data)
    

可迭代数据集

  • 通过迭代方式逐步获取数据样本。适用于随机访问成本不太高或者不可行的方案。
    #迭代器作为输入源
    class IterableDataset():
        def __init__(self, start, end):
            '''init the class object to hold the data'''
            self.start = start
            self.end = end
        def __next__(self):
            '''iter one data and return'''
            return next(self.data)
        def __iter__(self):
            '''reset the iter'''
            self.data = iter(range(self.start, self.end))
            return self
    

生成器

  • 直接依赖Python的生成器类型generator返回数据,直至生成器抛出异常
    def my_generator(start, end):
        for i in range(start, end):
            yield i
    
  • 25
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值