基于MindSpore的早停算法

转载地址:https://bbs.huaweicloud.com/forum/thread-136522-1-1.html

作者:李响

邮箱:chaojililin@163.com

分享个人项目中基于MindSpore的早停算法

防止训练过程过拟合以及epoch设置过大导致训练时长过长,利用早停算法来避免这些问题,本人基于MindSpore编写早停算法代码,可直接拿来用

建立lx_tool.py,代码如下:

class BestAccSaver(object):
    def __init__(self, FLAGS, network):
        """
        初始化函数
        """
        self.best_acc = 0.0
        self.FLAGS = FLAGS
        self.network = network
    
    def handle_acc_inf(self, val_acc):
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            os.makedirs(self.FLAGS.inner_output_path + "best_model", exist_ok=True)
            if os.path.exists(self.FLAGS.inner_output_path + "best_model" + os.path.sep + 'best_acc.ckpt'):
                os.chmod(self.FLAGS.inner_output_path + "best_model" + os.path.sep + 'best_acc.ckpt', stat.S_IWUSR)
                os.remove(self.FLAGS.inner_output_path + "best_model" + os.path.sep + 'best_acc.ckpt')
            save_checkpoint(self.network, self.FLAGS.inner_output_path + "best_model" + os.path.sep + 'best_acc.ckpt')
            print("modify best acc:", self.best_acc)
        return
    
    def revert_to_best(self): #恢复模型参数
        param_dict = load_checkpoint(self.FLAGS.inner_output_path + "best_model" + os.path.sep + 'best_acc.ckpt')
        load_param_into_net(self.network, param_dict)
        return

其中,FLAGS是在调用处设置的参数处理,比如:

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        #default=r"D:\ATransforLearn\besttextclassifier\tensorflow\v1.0.0\train_code\traindata\bbc",
        default=r"D:\ATransforLearn\besttextclassifier\mindspore\v1.0.0\train_code\traindata\toutiao_2w.json",
        help='Path to folders of labeled text.'
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='How large a learning rate to use when training.'
    )
    parser.add_argument(
        '--how_many_training_steps',
        type=int,
        default=500,
        help='How many training steps to run before ending.'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='training batch size.'
    )
    parser.add_argument(
        '--fix_length',
        type=int,
        default=-1,
        help='cut or add words to fix length.'
    )
    parser.add_argument(
        '--eval_step_interval',
        type=int,
        default=10,
        help='How often to evaluate the training results.'
    )
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of images to use as a test set.'
    )
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of images to use as a validation set.'
    )
    parser.add_argument(
        '--three_parts',
        default=False,
        type=tools.str2bool,
        help='if True, the dataset will be separated into three parts'
    )

    parser.add_argument(
        '--termination_patience',
        type=int,
        default=10,
        help='terminate after these times of bad performance'
    )
    parser.add_argument(
        '--use_early_termination',
        default=True,
        type=tools.str2bool,
        help='if True, use early termination strategy'
    )

    parser.add_argument(
        '--inner_input_path',
        type=str,
        default='./modelinput'
    )

    parser.add_argument(
        '--inner_output_path',
        type=str,
        default='./modeloutput'
    )

    parser.add_argument(
        '--current_path',
        default='./',
        type=str,
    )

    parser.add_argument(
        '--language',
        default="chinese",  # chinese, english
        type=str,
    )

    parser.add_argument(
        '--high_acc_mode',
        default=False,
        type=tools.str2bool,
    )
    return parser
###主函数###
parser = get_parser()
FLAGS, unparsed = parser.parse_known_args()

在调用处配置如下代码:

if "is_early_stop" in os.modelinf:
    monitor.best_acc_saver.revert_to_best()
    print("revert to best model")
monitor.best_acc_saver.revert_to_best()  # 就算是正常情况下是要恢复至最好的精度模型
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值