本文介绍三种MindSpore权重模型保存方法
模型权重直接保存
MindSpore的权重文件称为checkpoint,以ckpt为文件后缀,使用MindSpore提供的统一接save_checkpoint对模型进行保存,代码中model为事先训练好的模型。
import mindspore
ckpt_file_path = './output/best_model_19.ckpt' # 保存路径
# 模型额外相关信息保存,如epoch, batch_size,
append_info = dict()
append_info['batch_size'] = batch_size
append_info['version'] = mindspore.__version__
append_info['epoch'] = epoch
mindspore.save_checkpoint(model,
ckpt_file_path,
append_dict=append_info)
在训练时进行保存
MindSpore可以指定在训练结束后自动保存模型文件,需要在callbacks里指定保存的路径:
CheckPoint配置策略:
MindSpore有两种保存CheckPoint策略:迭代策略和时间策略,可以通过创建CheckpointConfig对象设置相应策略。 CheckpointConfig中有四个参数可以自定义设置:
- save_checkpoint_steps:表示每隔多少个step保存一个CheckPoint文件,默认值为1。
- save_checkpoint_seconds:表示每隔多少秒保存一个CheckPoint文件,默认值为0。
- keep_checkpoint_max:表示最多保存多少个CheckPoint文件,默认值为5。
- keep_checkpoint_per_n_minutes:表示每隔多少分钟保留一个CheckPoint文件,默认值为0。
用法1:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
config_ck = CheckpointConfig(save_checkpoint_steps=5,
keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='resnet50',
directory=save_path,
config=config_ck)
model.train(epoch_num,
dataset,
callbacks=ckpoint_cb)
用法2:
kpt_file_path = './output/best_model_19.ckpt' # 保存路径
# 在训练时指定保存路径,训练结束后会对该模型文件进行保存
model.train(epochs,
dataset_train,
callbacks=[ValAccMonitor(
model,
dataset_val,
epochs,
ckpt_directory=ckpt_file_path)])
保存最优模型
此方法可以保证保存的模型权重为在验证集中表现最优的权重:
## 定义单步训练
train_one_step = nn.TrainOneStepCell(net_with_loss, optimizer)
for epoch in range(epochs):
train_one_epoch(train_one_step, imdb_train, epoch)
valid_loss = evaluate(net, imdb_valid, loss, epoch)
# 判断在验证集的损失对比目前最小损失值有没有更小
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
# 最终保证
save_checkpoint(net, ckpt_file_name)