问题描述:
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
# 定义数据集
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# 定义所需要操作的map映射
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# 使用map映射函数,将数据操作应用到数据集
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op], input_columns="image", num_parallel_workers=num_parallel_workers)
# 进行shuffle、batch、repeat操作
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(count=repeat_size)
return mnist_ds
中间的map映射不大懂。API文档看了也不大明白,求助!
还有头文件的意义能说明吗?谢谢
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
解答:
您好,map的作用是对指定的列(input_columns)做数据增强操作(operations)。
mindspore的数据增强是pipeline的模式,这就意味着:
mnist_ds = ds.MnistDataset(data_path)
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op], input_columns="image", num_parallel_workers=num_parallel_workers)
这几行代码只是构建了一个dataset对象,当数据整个pipeline被launch时,流水线中会按照以下顺序执行:
MnistDataset-->输出两列数据{"image": image_data, "label": label_data}
-->第一个map对label列做type cast操作-->输出两列数据{"image": image_data(保持不变), "label": label_data_new(类型因为做typecast发生了变化)}
-->第二个map对image做了一系列操作-->输出两列数据{"image": image_data_new(做数据增强而发生了变化), "label": label_data_new(保持不变)
mindspore内部数据增强执行顺序实际上是一个算子一个算子按序执行,只是每个算子都是多线程,所有构建了mindspore高性能的pipeline模式,这里补充一点,如果您对每一个算子的输出感兴趣,可以在对应算子之后加一行代码:
for item in mnist_ds:
print(item)
这行代码您可以放在MnistDataset之后,也可以放在map之后,放的位置决定了数据pipeline中将有多少个算子在执行,您也可以依此检查每行代码的输出结果。
您的第二个问题是以下import的意义,我这边一一解读:
import mindspore.dataset as ds
表示您可以使用mindspore所有的dataset对象,包括MnistDataset,map、batch等。
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
c_transforms表示是mindspore内部实现的性能较高的C++算子,这些算子作为map的operations参数传入,然后对某一列做数据增强操作:
transforms.c_transforms主要是通用数值操作,包括typecast类型转换,onehot等等
vision.c_transforms则是图像的增强操作,底层实现是opencv
from mindspore.dataset.vision import Inter
前面提到vision是图像操作,Inter用于指定resize操作中插值算法,可选的有NEAREST,LINEAR,CUBIC,AREA,PILCUBIC
from mindspore import dtype as mstype
mindspore提供了一套数据类型对象,即dtype,mindspore所有转换类型的操作,指定的类型均为dtype。type_cast_op = C.TypeCast(mstype.int32)就是指定的输出类型是int32类型。