MindSpore数据集读取教程,如何读取自定义数据集

预备知识

预备知识,需要了解python 可迭代对象迭代器

可迭代对象:是指能够一次返回其中一个成员的对象,通常使用for循环 来完成此操作,如字符串、列表、元组、集合、字典等等之类的对象都属于可迭代对象。实现__getitem__(self, idx)函数的class就是一个可迭代对象。

MindSpore实现自定义数据集加载需要两个步骤:

假设当前需要读取一个CSV文件,格式如下:第一列为文本,第二列为标签,两列的标题分别为review,label。
在这里插入图片描述

步骤1:定义一个可迭代对象

实现__getitem__(self, idx)函数的class就是一个可迭代对象,下面对象通过_load(self)函数读取数据集,再通过__getitem__(self, idx)函数返回数据集。

class MyCSVData():
    """
    文本CSV数据集加载器
    加载数据集并处理为一个Python迭代对象。

    """

    def __init__(self, path):
        self.path = path
        self.review, self.label = [], []
        self._load()

    def _load(self):
        # 根据self.path load 需要读的csv文件
        with open(self.path, "r") as csv_file:
            dict_reader = csv.DictReader(csv_file)

            # 按行读取
            for row in dict_reader:
                review = row['review']
                label = int(np.float32(row['label']))
                # 数据处理
                label_onehot = [0] * 5
                label_onehot[label - 1] = 1
                review = re.split(' |,', review.lower())
                # 数据加载到List
                self.review.append(review)
                self.label.append(label_onehot)

    def __getitem__(self, idx):
        """
        定义可迭代对象返回当前结果的逻辑
        """
        return self.review[idx], self.label[idx]

    def __len__(self):
        """
        返回可迭代对象的长度
        :return: int
        """
        return len(self.review)

步骤2: 加载至GeneratorDataset(用于生成MindSpore数据对象)

GeneratorDataset会返回一个Dataset对象,该对象在后续训练中可以用于MindSpore训练及预测等操作,如下代码表示基于MyCSVData生成一个Dataset对象。

import mindspore.dataset as dataset 

csv_path = "./dataset/test1.csv"
data_train = dataset.GeneratorDataset(MyCSVData(csv_path),
                                      column_names=["review", "label"])      

通过以上两个步骤,就可以将任意格式的数据集读取为MindSpore的Dataset对象。


打印数据

data = next(data_train.create_dict_iterator())
print(data["review"])
print(data["label"])

分割验证集

data_train, data_valid = data_train.split([0.7, 0.3])

设置Batch Size

BATCH_SIZE = 128
data_train = data_train.batch(BATCH_SIZE, drop_remainder=True)
data_valid = data_valid.batch(BATCH_SIZE, drop_remainder=True)

训练

经过如上处理的数据可以直接基于MindSpore进行训练

model.train(num_epochs,
                data_train,
                callbacks=[
                    ValAccMonitor(model, data_valid, num_epochs, ckpt_directory='./model/')
                ])
  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值