tensorflow如何保存best_model

本文分享了一位新手在使用TensorFlow时如何解决保存最佳模型的问题。作者提出了在训练结束后,通过重新测试每个模型在测试集上的效果来选择最佳模型。尽管这种方法不涉及在训练过程中实时监测,但其简单易实现。文章详细描述了训练和测试阶段的代码实现,包括如何在一个会话中使用不同模型的参数,并通过循环遍历所有模型进行测试。最后,作者总结了这个不优雅但可行的解决方案,期待读者的反馈和建议以改进方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


背景:

在最近的实验中,用到了tensorflow做为我的模型框架。对于我而言,早已听说tensorflow的“赫赫威名”,但实际上我还属于tensorflow的一个新手,因此在实际应用中也遇到了许多问题。今天总结一下,在使用tensorflow框架时,我是如何解决保存最佳model这个问题。



问题描述:

对于保存最佳模型,我们有两种思路:

  • 在每个epoch过后,计算改epoch所得到的模型在test数据上的结果。

                优点: 不需要将所有epoch的model都保存下来

  • 在training任务完成后,在重新测试这些模型在test数据上的结果,选出最佳模型。

                优点:简单好实现   缺点:需要保存所有模型,比较占内存

在代码中,training和非training是两种状态。

def build_graph(self, is_training = True)

而is_training这个标志代表了数据的两种处理方式。

mean, var = tf.cond(tf.cast(is_training, tf.bool),
								mean_var_with_update,
								lambda: (ema.average(batch_mean), ema.average(batch_var)))



解决方案:

对于tensorflow的工作流程,我的理解如下:

 

 我看了一些博客和教程,他们介绍了tensorflow可以同时构建多个sess,也可以在同一个sess中运行多图,但是目前这些方法我还不会,因此只能另辟蹊径。

回到正题,上面说了我们的training和非training是两种状态,因此training阶段和test阶段实际上用到的是两个不同的graph,但是目前我还不会在同一sess下运行多图......

怎么办?当然办法要比困难多啦!

那就用我们上面提到的第二种保存最佳model的方法,先进行training阶段,然后在test,选出最佳model。整个流程可以用下面几段代码解释。

if __name__ == '__main__':
			
	#开始训练
	train(model_dir,model_name,epoch,batch_size)
		
	tf.reset_default_graph()	
	test(model_dir,best_result_dir,beat_test_log,one_model,batch_size)

如何test:

那么到目前为止,我们的已经选出了如何去实现这一功能的方案,虽然不是最优,但它可以work。

 那么如何在training结束之后,依次去测试每个模型在test数据的效果呢?

这个时候我们在来想一想tensorflow它的test机制。tensorflow有两种加载已保存模型的方法,简单来说有两种。参考

  • 自己构建图,加载保存好模型的参数。
  • 不用自己构建图,而是加载模型中的图,加载保存好模型的参数。

两种方法都行,反正就构建好图,在加载模型参数罢了。那么如果我只想运行一次,然后就把所有已保存模型都检验一次呢?上面说过,tensorflow的机制是在sess中运行graph,我们目前只会一个sess运行一个graph,已保存的模型肯定不止一个,怎么办?

 

这时候我们就要充分发挥我们的主观能动性,虽然说一个sess下只能有一个图,但是我可以在一个图中用不同的参数呀。深度学习不就是得到网络和参数这两个东西吗(这是我目前很浅薄的个人理解,如果有误希望大家指出来,感谢!)?反正这些模型,它们的graph或者说网络都一样,不同的是参数,那么我就只要定义一次graph就可以了,剩下的换不同模型的参数即可。

	#获取所有模型的名称
	ckpt = tf.train.get_checkpoint_state(model_dir)
	all_models  = ckpt.all_model_checkpoint_paths
	print(all_models)

	#GPU settings 90% memory usage
	config = tf.ConfigProto() 
	config.gpu_options.per_process_gpu_memory_fraction = 1.0
	config.gpu_options.allow_growth = True	

	with tf.Session(config = config) as sess:    
		
		model = Dense(sess, 'test',batch_size)	
		for one_model in all_models:
			print("Now is test {}th Epochs----------------".format(one_model.split('.')[0].split('_')[-1]))
			
			# Load data
			data = DataFetcher('test', batch_size = batch_size)
			data.setDaemon(True)
			data.start()

			test(model_dir,best_result_dir,beat_test_log,one_model,batch_size)

总结:

这个解决方案,是个不优雅的处理方式,但是我目前唯一想到的idea。做个总结,希望能对一些朋友有帮助。当然,也希望大家多多批评,让这个idea能更上一层楼。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值