转载地址:https://bbs.huaweicloud.com/forum/thread-136522-1-1.html
作者:李响
邮箱:chaojililin@163.com
分享个人项目中基于MindSpore的早停算法
为了防止训练过程中出现过拟合,以及 epoch 设置过大导致训练时间过长,可以利用早停算法来避免这些问题。本人基于 MindSpore 编写了早停相关代码:训练中保存验证精度最高的模型,训练结束(或提前终止)后恢复该模型,可直接拿来用。
新建 lx_tool.py,代码如下(文件开头需导入 os、stat,以及 MindSpore 的 save_checkpoint、load_checkpoint、load_param_into_net):
class BestAccSaver(object):
    """Keep a checkpoint of the network with the best validation accuracy seen so far.

    The checkpoint lives at
    ``<FLAGS.inner_output_path>/best_model/best_acc.ckpt`` and can be loaded
    back into the network with :meth:`revert_to_best`.

    Note: ``save_checkpoint``, ``load_checkpoint`` and ``load_param_into_net``
    are MindSpore functions that must be imported at module level.
    """

    _BEST_DIR = "best_model"  # subdirectory of inner_output_path
    _CKPT_NAME = "best_acc.ckpt"  # file name of the best checkpoint

    def __init__(self, FLAGS, network):
        """Initialize the saver.

        Args:
            FLAGS: parsed argument namespace; only ``inner_output_path`` is read.
            network: the network whose parameters are saved and restored.
        """
        self.best_acc = 0.0
        self.FLAGS = FLAGS
        self.network = network

    def _best_model_dir(self):
        """Return the directory that holds the best checkpoint."""
        import os

        # os.path.join inserts the separator that plain string concatenation
        # dropped ("./modeloutput" + "best_model" -> "./modeloutputbest_model").
        return os.path.join(self.FLAGS.inner_output_path, self._BEST_DIR)

    def _ckpt_path(self):
        """Return the full path of the best checkpoint file."""
        import os

        return os.path.join(self._best_model_dir(), self._CKPT_NAME)

    def handle_acc_inf(self, val_acc):
        """Save the network if ``val_acc`` beats the best accuracy so far.

        Args:
            val_acc: validation accuracy of the current evaluation.

        Returns:
            None.
        """
        import os
        import stat

        if val_acc > self.best_acc:
            self.best_acc = val_acc
            ckpt_path = self._ckpt_path()
            os.makedirs(self._best_model_dir(), exist_ok=True)
            if os.path.exists(ckpt_path):
                # NOTE(review): presumably the previous checkpoint is read-only;
                # grant owner write permission so it can be removed — confirm.
                os.chmod(ckpt_path, stat.S_IWUSR)
                os.remove(ckpt_path)
            save_checkpoint(self.network, ckpt_path)
            print("modify best acc:", self.best_acc)
        return

    def revert_to_best(self):
        """Load the best saved checkpoint back into ``self.network``."""
        param_dict = load_checkpoint(self._ckpt_path())
        load_param_into_net(self.network, param_dict)
        return
其中,FLAGS 是在调用处用 argparse 解析得到的参数对象(BestAccSaver 只用到其中的 inner_output_path),例如:
def get_parser():
    """Build the command-line parser for the text-classification training script.

    Boolean options are converted with ``tools.str2bool``, a project helper
    that is not shown here.

    Returns:
        argparse.ArgumentParser: parser with all training options registered.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        # NOTE(review): machine-specific default; pass --data_dir explicitly
        # on any other machine.
        default=r"D:\ATransforLearn\besttextclassifier\mindspore\v1.0.0\train_code\traindata\toutiao_2w.json",
        help='Path to the labeled text data (a folder or a JSON file).'
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='How large a learning rate to use when training.'
    )
    parser.add_argument(
        '--how_many_training_steps',
        type=int,
        default=500,
        help='How many training steps to run before ending.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='training batch size.'
    )
    parser.add_argument(
        '--fix_length',
        type=int,
        default=-1,
        help='cut or add words to fix length.'
    )
    parser.add_argument(
        '--eval_step_interval',
        type=int,
        default=10,
        help='How often to evaluate the training results.'
    )
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of samples to use as a test set.'
    )
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of samples to use as a validation set.'
    )
    parser.add_argument(
        '--three_parts',
        default=False,
        type=tools.str2bool,
        help='if True, the dataset will be separated into three parts'
    )
    parser.add_argument(
        '--termination_patience',
        type=int,
        default=10,
        help='terminate after these times of bad performance'
    )
    parser.add_argument(
        '--use_early_termination',
        default=True,
        type=tools.str2bool,
        help='if True, use early termination strategy'
    )
    parser.add_argument(
        '--inner_input_path',
        type=str,
        default='./modelinput'
    )
    parser.add_argument(
        '--inner_output_path',
        type=str,
        default='./modeloutput'
    )
    parser.add_argument(
        '--current_path',
        default='./',
        type=str,
    )
    parser.add_argument(
        '--language',
        default="chinese",
        type=str,
        help='Corpus language: "chinese" or "english".'
    )
    parser.add_argument(
        '--high_acc_mode',
        default=False,
        type=tools.str2bool,
    )
    return parser
# --- Main script ---
# Build the CLI parser and parse the command line (sys.argv[1:] when
# args is None). parse_known_args() does not exit on unrecognised
# options; it returns them separately in `unparsed`.
parser = get_parser()
(FLAGS, unparsed) = parser.parse_known_args(args=None)
训练结束后,在调用处添加如下代码,将网络恢复为验证精度最高的模型:
# Restore the best-accuracy checkpoint once training has finished. This is
# done whether training stopped early or ran to completion; the early-stop
# case only adds a log line. (Previously the early-stop path called
# revert_to_best() twice, loading the same checkpoint two times.)
# NOTE(review): `os.modelinf` is not part of the standard `os` module; it is
# presumably a container attached to `os` elsewhere in the project — confirm.
monitor.best_acc_saver.revert_to_best()
if "is_early_stop" in os.modelinf:
    print("revert to best model")