MindSpore提供了回调Callback机制,可以在训练过程中执行自定义逻辑,这里以使用框架提供的ModelCheckpoint为例。
-
ModelCheckpoint可以保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
# 设置模型保存参数
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# 应用模型保存参数
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
-
通过MindSpore提供的model.train接口可以方便地进行网络的训练,LossMonitor可以监控训练过程中loss值的变化。
# 导入模型训练需要的库
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model
-
def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
"""定义训练的方法"""
# 加载训练数据集
ds_train = create_dataset(os.path.join(data_path, "train"), 32, rep