分享个人项目中基于MindSpore的早停算法
为防止训练过程中出现过拟合，以及 epoch 设置过大导致训练时间过长，可以使用早停（early stopping）算法来避免这些问题。本人基于 MindSpore 编写了早停算法代码，可直接拿来使用。
新建 lx_tool.py，代码如下（文件开头需导入 os、stat，以及 MindSpore 的 save_checkpoint、load_checkpoint、load_param_into_net）：
class BestAccSaver(object):
    """Keep a checkpoint of the network with the best validation accuracy so far.

    The checkpoint lives at
    ``<FLAGS.inner_output_path>/best_model/best_acc.ckpt`` and can be loaded
    back into the network with :meth:`revert_to_best`.
    """

    # Sub-directory and file name of the best-accuracy checkpoint.
    _BEST_DIR = "best_model"
    _BEST_FILE = "best_acc.ckpt"

    def __init__(self, FLAGS, network):
        """Initialize the saver.

        Args:
            FLAGS: parsed command-line flags; only ``FLAGS.inner_output_path``
                is read.
            network: the network whose parameters are saved and restored.
        """
        self.best_acc = 0.0
        self.FLAGS = FLAGS
        self.network = network

    def _best_ckpt_path(self):
        """Return the path of the best-accuracy checkpoint file."""
        # Use os.path.join rather than "+": with an output path that has no
        # trailing separator (e.g. the default './modeloutput'), plain string
        # concatenation yielded './modeloutputbest_model'.
        return os.path.join(
            self.FLAGS.inner_output_path, self._BEST_DIR, self._BEST_FILE
        )

    def handle_acc_inf(self, val_acc):
        """Save a checkpoint when ``val_acc`` beats the best accuracy seen so far.

        Args:
            val_acc: validation accuracy of the current evaluation.
        """
        # Written as "not >" (not "<=") so a NaN accuracy is never treated as
        # an improvement.
        if not val_acc > self.best_acc:
            return
        self.best_acc = val_acc
        ckpt_path = self._best_ckpt_path()
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        if os.path.exists(ckpt_path):
            # Make the previous checkpoint writable before deleting it; it may
            # have been left read-only (presumably by save_checkpoint — confirm),
            # which can make os.remove fail on some platforms.
            os.chmod(ckpt_path, stat.S_IWUSR)
            os.remove(ckpt_path)
        save_checkpoint(self.network, ckpt_path)
        print("modify best acc:", self.best_acc)

    def revert_to_best(self):
        """Load the best-accuracy checkpoint back into ``self.network``.

        If no checkpoint has been saved yet (accuracy never exceeded the
        initial 0.0), print a message and keep the current parameters instead
        of failing on a missing file.
        """
        ckpt_path = self._best_ckpt_path()
        if not os.path.exists(ckpt_path):
            print("no best checkpoint found at", ckpt_path, "- keeping current parameters")
            return
        param_dict = load_checkpoint(ckpt_path)
        load_param_into_net(self.network, param_dict)
其中，FLAGS 是在调用处通过 argparse 解析得到的参数对象，例如：
def get_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser holding the data, training,
        early-termination (``--use_early_termination``,
        ``--termination_patience``) and path options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        # NOTE(review): machine-specific absolute Windows path; override it
        # on the command line when running elsewhere.
        default=r"D:\ATransforLearn\besttextclassifier\mindspore\v1.0.0\train_code\traindata\toutiao_2w.json",
        help='Path to the labeled text data (a folder of labeled text or a JSON file).'
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='How large a learning rate to use when training.'
    )
    parser.add_argument(
        '--how_many_training_steps',
        type=int,
        default=500,
        help='How many training steps to run before ending.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='training batch size.'
    )
    parser.add_argument(
        '--fix_length',
        type=int,
        default=-1,
        help='cut or add words to fix length.'
    )
    parser.add_argument(
        '--eval_step_interval',
        type=int,
        default=10,
        help='How often to evaluate the training results.'
    )
    parser.add_argument(
        '--testing_percentage',
        type=int,
        default=10,
        help='What percentage of samples to use as a test set.'
    )
    parser.add_argument(
        '--validation_percentage',
        type=int,
        default=10,
        help='What percentage of samples to use as a validation set.'
    )
    parser.add_argument(
        '--three_parts',
        default=False,
        type=tools.str2bool,
        help='if True, the dataset will be separated into three parts'
    )
    parser.add_argument(
        '--termination_patience',
        type=int,
        default=10,
        help='Terminate training after this many evaluations with bad performance.'
    )
    parser.add_argument(
        '--use_early_termination',
        default=True,
        type=tools.str2bool,
        help='if True, use early termination strategy'
    )
    parser.add_argument(
        '--inner_input_path',
        type=str,
        default='./modelinput'
    )
    parser.add_argument(
        '--inner_output_path',
        type=str,
        default='./modeloutput'
    )
    parser.add_argument(
        '--current_path',
        default='./',
        type=str,
    )
    parser.add_argument(
        '--language',
        default="chinese",
        type=str,
        help='Language of the corpus: "chinese" or "english".'
    )
    parser.add_argument(
        '--high_acc_mode',
        default=False,
        type=tools.str2bool,
    )
    return parser
# ---------------- main ----------------
# parse_known_args() (unlike parse_args()) does not reject unrecognized
# command-line options; it returns them in `unparsed` instead.
parser = get_parser()
(FLAGS, unparsed) = parser.parse_known_args()
然后在调用处添加如下代码，将模型参数恢复为验证精度最高的版本：
# Restore the parameters with the best validation accuracy. This happens in
# every case, not only after early stopping, so revert_to_best() is called
# exactly once. Previously it ran twice when early stopping had fired,
# loading the same checkpoint twice.
# NOTE(review): `os.modelinf` is not part of the standard `os` module;
# presumably it is attached elsewhere by the training code — confirm.
if "is_early_stop" in os.modelinf:
    print("revert to best model")
monitor.best_acc_saver.revert_to_best()