问题描述:
首先,MindSpore的用户自定义数据集mindspore.dataset.GeneratorDataset
最多只能输出二元组(后续训练过程中Model.train
方法限定当loss_fn
为None
时使用一元组,不为None
时使用二元组),即data
和label
,如果我想要返回其他信息(如data
的来源文件),那么必须要么和data
一起输出,要么和label
一起输出。然而,这是不被允许的。因为mindspore.dataset.Generator
要求其输出必须是numpy.ndarray
二元组,如果要其他信息加入data
,也必须保证其结果是numpy.ndarray
才行,同理,加入label
亦然。这就导致mindspore.dataset.Generator
的数据控制能力远远低于PyTorch
。
明明封装层次更高,反而更难以使用,这是不可理解的。
我要做的是主题分割相关的,在复现一个论文模型时需要处理长短不一的句子。由于MindSpore并未提供类似于PyTorch中pack_padded_sequence
之类的rnn_utils
,所以导致我需要自己处理句子长度,但是现在句子长度憋死到mindspore.dataset.GeneratorDataset
类中了。按照API中的提醒,可以在Model.train
方法中设置loss_fn
为None
,然后传入自行实现的带损失的网络,这样可以处理多标签类数据,但是这样的方法变得麻烦了起来。
请问有什么比较友好的处理方法吗?
个人认为能优雅地处理mindspore.dataset.GeneratorDataset
至关重要,对它的source
输出格式采取过多限制(要求迭代的出来的对象必须是numpy.ndarray
)会使诸多简单问题变得复杂起来。
解答:
首先,非常感谢您认真的体验及建议。 1. 需要纠正的一点,GeneratorDataset 没有限定用户只能返回二元组,你可能是看了 “编程指南”中例子,被误导了,GeneratorDataset是可以支持用户自定义返回多个列的。如下面的例子,如果你还想更多列,那么在 __getitem__ 中返回 tuple 中元素的个数和 GeneratorDataset中参数column_names对应起来即可。
import numpy as np
import mindspore.dataset as ds
class GetDatasetGenerator:
def __init__(self):
np.random.seed(58)
self.__data = np.random.sample((5, 2))
self.__label = np.random.sample((5, 1))
self.__label2 = np.random.sample((5, 2, 3, 4))
self.__label3 = np.random.sample((5, 64))
self.__file_name = np.array([["1.jpg"], ["5.jpg"], ["3.jpg"], ["4.jpg"], ["5.jpg"],])
def __getitem__(self, index):
return (self.__data[index], self.__label[index], self.__label2[index], self.__label3[index], self.__file_name[index])
def __len__(self):
return len(self.__data)
dataset_generator = GetDatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label", "label2", "label3", "file_name"], shuffle=False)
for data in dataset.create_dict_iterator():
print(data)
2. 对于 pack_padded_sequence 类似的功能,你可以查看下如下接口的用法,它是数据集对象的一个方法,看是否能满足你的要求。 https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/dataset/mindspore.dataset.GeneratorDataset.html#mindspore.dataset.GeneratorDataset.bucket_batch_by_length 3. 关于GeneratorDataset的source必须返回是numpy对象的问题,我们最新版本中已经支持了用户返回int, float, str, bytes, np.ndarray这些基本类型,这些基本类型可以放在tuplek中返回,你可以尝试下,当然我们也在进一步思考支持更多类型的问题。 建议用户以numpy对象返回,主要考虑是因为我们的数据集预处理引擎是在C层运行的,用户返回的数据都需要在Python层->C层传递->Python层频繁传递,所以一个确定唯一的数据类型对于这样的传递是比较友好的。 欢迎收到你的进一步反馈。