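MindSpore provides ready-made APIs for loading common datasets, including CelebA, Cifar100, Cifar10, Coco, ImageNet, MNIST, and VOC.
The following uses the Cifar10 dataset as an example to show how to call the interface and display images from the data.
Below is the API signature and description from the official documentation: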
class mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None)
A source dataset for reading and parsing the Cifar10 dataset. This API currently only supports parsing Cifar10 files in the binary version.
The generated dataset has two columns: [image, label]. The tensor of column image is of the uint8 type. The tensor of column label is a scalar of the uint32 type.
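To make the "binary version" note more concrete, here is a minimal sketch of decoding such a file by hand. It assumes the standard Cifar10 binary layout: each record is 1 label byte followed by 3072 pixel bytes, stored as 1024 red, 1024 green, and 1024 blue values. The helper name `parse_cifar10_records` is hypothetical and is not part of MindSpore; this is an illustration of the file format, not how MindSpore implements its reader.

```python
# Assumed layout of the Cifar10 binary format: 1 label byte + 32*32*3 pixel bytes.
RECORD_BYTES = 1 + 3 * 32 * 32


def parse_cifar10_records(buf):
    """Hypothetical helper: split raw bytes into (image_bytes, label) pairs.

    image_bytes holds 3072 uint8 values in height x width x channel order,
    the layout that image libraries usually expect.
    """
    if len(buf) % RECORD_BYTES != 0:
        raise ValueError("buffer length is not a multiple of the record size")
    records = []
    for offset in range(0, len(buf), RECORD_BYTES):
        label = buf[offset]  # first byte of the record is the class label
        pixels = buf[offset + 1: offset + RECORD_BYTES]
        # The file stores the three colour planes one after another.
        red, green, blue = pixels[:1024], pixels[1024:2048], pixels[2048:]
        # Interleave the planes so that each pixel becomes an (R, G, B) triple.
        hwc = bytes(v for rgb in zip(red, green, blue) for v in rgb)
        records.append((hwc, label))
    return records


# Synthetic record for illustration: label 3, uniform colour (10, 20, 30).
record = bytes([3]) + bytes([10]) * 1024 + bytes([20]) * 1024 + bytes([30]) * 1024
image, label = parse_cifar10_records(record)[0]
print(label, len(image), list(image[:3]))
```

If NumPy is available, `np.frombuffer(image, dtype=np.uint8).reshape(32, 32, 3)` turns the interleaved bytes into an array of the same shape and type as the image column described above.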
Calling the interface and displaying the images:
import mindspore.dataset as ds
from PIL import Image
import matplotlib.pyplot as plt

# Placeholder: set this to the directory that holds the Cifar10 binary files
data_dir = "path/to/cifar-10-batches-bin"

# Take the first 6 samples in order
sampler = ds.SequentialSampler(num_samples=6)
dataset = ds.Cifar10Dataset(data_dir, sampler=sampler)

# Create an iterator over the dataset; each retrieved item is a dictionary
for i, data in enumerate(dataset.create_dict_iterator()):
    print("Image shape: {}, Label: {}".format(data['image'].shape, data['label']))
    image = data['image'].asnumpy()  # mindspore.Tensor to numpy
    image = Image.fromarray(image)
    # Draw the image in cell i + 1 of a 2 x 3 grid
    plt.subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.title(f"{i + 1}", fontsize=6)
    plt.xticks([])
    plt.yticks([])
plt.show()
Result: