如何保存MindSpore模型

本文介绍三种MindSpore权重模型保存方法

模型权重直接保存

MindSpore的权重文件称为checkpoint,以ckpt为文件后缀,使用MindSpore提供的统一接save_checkpoint对模型进行保存,代码中model为事先训练好的模型。

import mindspore

ckpt_file_path = './output/best_model_19.ckpt' # 保存路径

# 模型额外相关信息保存,如epoch, batch_size, 
append_info = dict()
append_info['batch_size'] = batch_size
append_info['version'] = mindspore.__version__
append_info['epoch'] = epoch

mindspore.save_checkpoint(model,
                          ckpt_file_path,
                          append_dict=append_info)

在训练时进行保存

MindSpore可以指定在训练结束后自动保存模型文件,需要在callbacks里指定保存的路径:
CheckPoint配置策略:
MindSpore有两种保存CheckPoint策略:迭代策略和时间策略,可以通过创建CheckpointConfig对象设置相应策略。 CheckpointConfig中有四个参数可以自定义设置:

  1. save_checkpoint_steps:表示每隔多少个step保存一个CheckPoint文件,默认值为1。
  2. save_checkpoint_seconds:表示每隔多少秒保存一个CheckPoint文件,默认值为0。
  3. keep_checkpoint_max:表示最多保存多少个CheckPoint文件,默认值为5。
  4. keep_checkpoint_per_n_minutes:表示每隔多少分钟保留一个CheckPoint文件,默认值为0。

用法1:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config_ck = CheckpointConfig(save_checkpoint_steps=5,
                             keep_checkpoint_max=10)
 
ckpoint_cb = ModelCheckpoint(prefix='resnet50',
                             directory=save_path,
                             config=config_ck)
model.train(epoch_num,
            dataset,
            callbacks=ckpoint_cb)

用法2:

kpt_file_path = './output/best_model_19.ckpt' # 保存路径

# 在训练时指定保存路径,训练结束后会对该模型文件进行保存
model.train(epochs,
            dataset_train,
            callbacks=[ValAccMonitor(
                model,
                dataset_val,
                epochs,
                ckpt_directory=ckpt_file_path)])

保存最优模型

此方法可以保证保存的模型权重为在验证集中表现最优的权重:

## 定义单步训练
train_one_step = nn.TrainOneStepCell(net_with_loss, optimizer)

for epoch in range(epochs):
    train_one_epoch(train_one_step, imdb_train, epoch)
    valid_loss = evaluate(net, imdb_valid, loss, epoch)
    
	# 判断在验证集的损失对比目前最小损失值有没有更小
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # 最终保证
        save_checkpoint(net, ckpt_file_name)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值