问题描述:
训练可以训练,测试时报这个错
问题出在了这行代码上
训练集和测试集格式也一样,求救
解答:
请问你的DatasetGenerator()是怎么定义的,使用GeneratorDataset()接口,需要使用你定义的DatasetGenerator()里的方式加载访问数据集,用法可以参考下面:
class MyAccessible:
def __init__(self):
self._data = np.random.sample((5, 2))
self._label = np.random.sample((5, 1))
def __getitem__(self, index):
return self._data[index], self._label[index]
def __len__(self):
return len(self._data)
dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
# list, dict, tuple of Python is also random accessible
dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"])
https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/dataset/mindspore.dataset.GeneratorDataset.html
https://mindspore.cn/docs/programming_guide/zh-CN/r1.5/pipeline_common.html