本教程包含MindSpore图像数据读取和一些基本数据集操作。
MindSpore 加载图像数据集
MindSpore加载图像数据集常用API:mindspore.dataset.ImageFolderDataset
数据集需要的分布格式为每个类别一个一个文件夹,文件夹内部结构如图,每个文件夹的名称会默认作为数据集的label:
加载
读取代码如下,执行结束后会生成一个Dataset对象:
import mindspore.dataset as ds
dataset_dir = "C:\\datasets\\caltech_for_user\\train"
dataset = ds.ImageFolderDataset(dataset_dir, decode=True)
MindSpore图像预处理
预处理
MindSpore1.8以前的版本:
图像预处理统一使用mindspore.dataset.vision.c_transforms模块,其中包含变形(Resize)、标准化(Normalize)、转置(HWC2HCW)等所有图像相关的预处理操作。
import mindspore.dataset.vision.c_transforms as c_transforms
image_size = 32
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3
trans = [
c_transforms.Resize((image_size, image_size)),
c_transforms.Normalize(mean=mean, std=std),
c_transforms.HWC2CHW()
]
dataset = dataset.map(operations=trans, num_parallel_workers=1)
MindSpore1.8及以后的版本
图像预处理统一使用 mindspore.dataset.vision as vision
import mindspore.dataset.vision as vision
image_size = 32
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3
trans = [
vision.Resize((image_size, image_size)),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
dataset = dataset.map(operations=trans, num_parallel_workers=1)
验证集分割
train, val = dataset.split([0.8, 0.2])
batch
如果需要使用mini-batch训练,需要使用如下代码对数据集进行处理:
batch_size = 128
train = train.batch(batch_size, drop_remainder=True)
打印图像数据集信息
MindSpore数据集使用create_dict_iterator()生成一个可迭代对象,然后使用next得到每一个样本,其中mindspore.dataset.ImageFolderDataset读取默认的图像的关键字为image,标签为label:
for i in range(5):
data = next(train.create_dict_iterator())
print(data['label'])
print(data['image'].shape)