完成数据预处理、网络定义、损失函数和优化器定义之后,就可以进行模型训练了。模型训练包含两层迭代,数据集的多轮迭代(epoch)和一轮数据集内按分组(batch)大小进行的单步迭代。其中,单步迭代指的是按分组从数据集中抽取数据,输入到网络中计算得到损失函数,然后通过优化器计算和更新训练参数的梯度。
-
为了简化训练过程,MindSpore封装了Model高阶接口。用户输入网络、损失函数和优化器完成Model的初始化,然后调用train接口进行训练,train接口参数包括迭代次数(epoch)和数据集(dataset)。
模型保存是对训练参数进行持久化的过程。Model类中通过回调函数(callback)的方式进行模型保存,如下面代码所示。用户通过CheckpointConfig设置回调函数的参数,其中,save_checkpoint_steps指每经过固定的单步迭代次数保存一次模型,keep_checkpoint_max指最多保存的模型个数。
'''
network, loss, optimizer are defined before.
batch_num, epoch_size are training parameters.
'''
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# CheckPoint CallBack definition
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=