Mindspore初学(二)
1.数据变换 Transforms
1.1加载数据
将原始数据转化为张量格式,方便数据可以直接用在神经网络模型中.
参数解释:
path
:数据集根目录位置。split
:训练、测试或推理数据集,可选train
,test
,infer
,默认参数是train
download
:是否下载数据集,当设置True
时,若数据集不存在可下载并解压数据集,默认为False
如加载数据cifar-10
数据集,本数据集包含训练数据和测试数据
from mindvision.dataset import Cifar10
#数据集根目录
data_dir="./datasets"
#下载解压并加载训练集
dataset=Cifar10(path=data_dir,split='train',batch_size=6,resize=32)
dataset=dataset.run()
1.2迭代数据
迭代访问数据,访问的数据类型默认为Tensor
;若设置output_numpy=True
,访问的数据类型为Numpy
。
data=next(dataset.create_dict_iterator())
data=next(datsset.create_dict_iterator(output_numpy=True))
1.3数据处理
参数解释:
shuffle
:是否打乱数据集的顺序,设置为True
时打乱数据集的顺序,默认为False
。
防止数据过拟合batch_size
:每组包含的数据个数,batch_size=2
设置每组包含2个数据,默认值为32。
加快数据运算repeat_num
:重复数据集的个数,repeat_num=1
即一份数据集,默认为1。
dataset=Cifar10(path=data_dir,split='train',batch_size=6,resize=32,repeat_num=1,shuffle=True)
1.4数据的增强
数据量过小会造成网络模型训练不起来,精度不达标。我们可以通过数据增强操作扩充样本的多样性,从而提升模型的泛化能力。
参数解释:
transform
:对数据集图像数据进行增强。batch_size
:对数据集标签数据进行处理
trans=[
transforms.RandomCrop((32,32),(4,4,4,4))#对图像进行自动裁剪
transforms.RandomHorizontalFlip(prob=0.5)#对图像进行随机水平翻转
transforms.HWC2CHW()#(h,w,c)转为为(c,h,w)
]
dataset=Cifar10(path=data_dir,split='train',batch_size=6,resize=32,transform=trans)
1.5Vision Transforms
mindspore.dataset.vision
模块提供一系列针对图像数据的Transforms。
比如:
Rescale
rescale
:缩放因子。
shift
:平移因子。
图像的每个像素将根据这两个参数进行调整,输出的像素值为outputi=inputi∗rescale+shift。
这里我们先使用numpy随机生成一个像素值在[0, 255]的图像,将其像素值进行缩放。
import numpy as np
random_np = np.random.randint(0, 255, (48, 48), np.uint8)
random_image = Image.fromarray(random_np)
print(random_np)
##输出
#[[103 163 149 ... 232 160 85]
#[ 17 54 82 ... 51 251 107]
#[133 88 21 ... 136 236 175]
#...
#[ 20 66 30 ... 116 19 247]
#[ 37 25 164 ... 128 23 113]
#[139 191 134 ... 239 83 133]]
Normalize:用于对输入图像的归一化,包括三个参数:
mean
:图像每个通道的均值。
std
:图像每个通道的标准差。
is_hwc
:输入图像格式为(height, width, channel)还是(channel, height, width)。
HWC2CWH:用于转换图像格式。在不同的硬件设备中可能会对(height, width, channel)或(channel, height, width)两种不同格式有针对性优化。MindSpore设置HWC为默认图像格式,在有CWH格式需求时,可使用该变换进行处理。