背景:
在最近的实验中,用到了tensorflow做为我的模型框架。对于我而言,早已听说tensorflow的“赫赫威名”,但实际上我还属于tensorflow的一个新手,因此在实际应用中也遇到了许多问题。今天总结一下,在使用tensorflow框架时,我是如何解决保存最佳model这个问题。
问题描述:
对于保存最佳模型,我们有两种思路:
- 在每个epoch过后,计算改epoch所得到的模型在test数据上的结果。
优点: 不需要将所有epoch的model都保存下来
- 在training任务完成后,在重新测试这些模型在test数据上的结果,选出最佳模型。
优点:简单好实现 缺点:需要保存所有模型,比较占内存
在代码中,training和非training是两种状态。
def build_graph(self, is_training = True)
而is_training这个标志代表了数据的两种处理方式。
mean, var = tf.cond(tf.cast(is_training, tf.bool),
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
解决方案:
对于tensorflow的工作流程,我的理解如下:
我看了一些博客和教程,他们介绍了tensorflow可以同时构建多个sess,也可以在同一个sess中运行多图,但是目前这些方法我还不会,因此只能另辟蹊径。
回到正题,上面说了我们的training和非training是两种状态,因此training阶段和test阶段实际上用到的是两个不同的graph,但是目前我还不会在同一sess下运行多图......
怎么办?当然办法要比困难多啦!
那就用我们上面提到的第二种保存最佳model的方法,先进行training阶段,然后在test,选出最佳model。整个流程可以用下面几段代码解释。
if __name__ == '__main__':
#开始训练
train(model_dir,model_name,epoch,batch_size)
tf.reset_default_graph()
test(model_dir,best_result_dir,beat_test_log,one_model,batch_size)
如何test:
那么到目前为止,我们的已经选出了如何去实现这一功能的方案,虽然不是最优,但它可以work。
那么如何在training结束之后,依次去测试每个模型在test数据的效果呢?
这个时候我们在来想一想tensorflow它的test机制。tensorflow有两种加载已保存模型的方法,简单来说有两种。参考
- 自己构建图,加载保存好模型的参数。
- 不用自己构建图,而是加载模型中的图,加载保存好模型的参数。
两种方法都行,反正就构建好图,在加载模型参数罢了。那么如果我只想运行一次,然后就把所有已保存模型都检验一次呢?上面说过,tensorflow的机制是在sess中运行graph,我们目前只会一个sess运行一个graph,已保存的模型肯定不止一个,怎么办?
这时候我们就要充分发挥我们的主观能动性,虽然说一个sess下只能有一个图,但是我可以在一个图中用不同的参数呀。深度学习不就是得到网络和参数这两个东西吗(这是我目前很浅薄的个人理解,如果有误希望大家指出来,感谢!)?反正这些模型,它们的graph或者说网络都一样,不同的是参数,那么我就只要定义一次graph就可以了,剩下的换不同模型的参数即可。
#获取所有模型的名称
ckpt = tf.train.get_checkpoint_state(model_dir)
all_models = ckpt.all_model_checkpoint_paths
print(all_models)
#GPU settings 90% memory usage
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 1.0
config.gpu_options.allow_growth = True
with tf.Session(config = config) as sess:
model = Dense(sess, 'test',batch_size)
for one_model in all_models:
print("Now is test {}th Epochs----------------".format(one_model.split('.')[0].split('_')[-1]))
# Load data
data = DataFetcher('test', batch_size = batch_size)
data.setDaemon(True)
data.start()
test(model_dir,best_result_dir,beat_test_log,one_model,batch_size)
总结:
这个解决方案,是个不优雅的处理方式,但是我目前唯一想到的idea。做个总结,希望能对一些朋友有帮助。当然,也希望大家多多批评,让这个idea能更上一层楼。