《昇思25天学习打卡营第3天 | mindspore DataSet 数据集的常见用法》

1. 背景:

使用 mindspore 学习神经网络,打卡第三天;

2. 训练的内容:

使用 mindspore 的常见的数据集 DataSet 的使用方法;

3. 常见的用法小节:

  • 数据集加载
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
  • 数据集迭代(create_tuple_iterator或create_dict_iterator 实现)
def visualize(dataset):
    figure = plt.figure(figsize=(4,4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx +1)
        plt.title(int(label))
        plt.axis('off')
        plt.imshow(image.asnumpy().squeeze(), cmap='gray')
        if idx == cols * rows - 1:
            break;
       
    plt.show()
  
visualize(train_dataset)
  • 数据集常用操作(shuffer, map, batch):
# shuffer - 随机打乱数据顺序
train_dataset = train_dataset.shuffle(buffer_size=64)
visualize(train_dataset)

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

# map - 对数据进行
# 将图像统一除以255,数据类型由uint8转为了float32
train_dataset = train_dataset.map(vision.Rescale(1.0/255.0, 0), input_columns='image')

#batch: 有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量
train_dataset = train_dataset.batch(batch_size=32)

# batch后的数据增加一维,大小为batch_size。
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)
  • 自定义数据集(可随机访问数据集/可迭代数据集/生成器类型)
# 自定义数据加载类,来生成数据集,通过 GeneratorDataset 接口实现数据加载
# 实现 __getitem__, __len__ 方法,进行 索引键直接访问

class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))
        
    def __getitem__(self, index):
        return self._data[index], self._label[index]
    
    def __len__(self):
        return len(self._data)
    

loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=['data', 'label'])

for data in dataset:
    print(data)

# 可跌代数据集,实现 __iter__, __next__ 方法
# 应用场景:iter(dataset),读取数据库,远程访问返回的数据流
class IterableDataset():
    def __init__(self, start, end):
        self.start = start
        self.end = end
        
    def __next__(self):
        return next(self.data)
    
    def __iter__(self):
        self.data = iter(range(self.start, self.end))
        return self


# 生成器:可迭代数据集类型,依赖 python 的 generator 返回数据
def my_generator(start, end):
    for i in range(start, end):
        yield i

dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=['data'])

for d in dataset:
    print(d)

活动参与链接:

https://xihe.mindspore.cn/events/mindspore-training-camp

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值