最近在训练一个检测器,由于训练数据不足因此需要做数据增强,那么我这边写了代码去做数据增强(这部分将会在下一篇进行介绍),其中使用到了tensorflow会话获取数据,可是问题出现了!gtx 1080ti显卡11G的内存,总共才处理六万张图像,但是运行速度越来越慢,眼看吃完晚饭散步回来几个小时了还没见有处理到一半,这就让我不得不探索个究竟并期待解决这类问题了!
此篇文章,我们只谈思路不谈代码,但会给一个基本的代码框架以便更好说明原因!
首先,使用如下命令2S间隔监测一次gtx 1080ti的使用情况!
watch -n 2 nvidia-smi
得到的信息情况如下:
内存已经差不多吃完了,没办法,tensorflow运行起来就是这么霸道,GPU显存有多少就基本吃多少,当然我们可以调整其使用内存量,但在这里不做讨论。再仔细看看,是不是发现什么了,没错啦,中部右侧的0%显示GPU使用率为0,what the fuck,什么情况啊,难道不是用GPU来计算处理的吗?可是代码跑起来确实有打印如下使用GPU信息的,而且驱动也没装错啊,之前都有正确运行了的!
那什么原因呢?我也不知道,但是总得想个办法来看看到底卡在哪里了,到底是读进图像时卡住了,还是tensorflow某个操作耗时太久了,还是?然后,我想到了个方法,使用datetime.datetime.now()获取时间,并打印某一个操作前后的时间差,哈哈,总算得到一些有效信息了!
看看一开始的打印信息:
起初打印出来的总耗时还是蛮低的,大概也就0.01~0.04s的范围,且分析到在23节点打印出来的耗时是最大的,而23节点是打印sess.run()处理前后的间隔耗时!
那么运行十来分钟后,此时的时间打印信息如下:
这时候打印出来的总耗时已经增大到0.7s左右了,足足增大了15~70倍的耗时时间,而这才运行了十来分钟而已!那么我们看到23节点打印出来的耗时此时也是最大的,而且整个耗时的增大也基本来自这个23节点即sess.run()产生的耗时!
问题直指sess.run()随着时间的拉长其运行速度越来越慢!
那么我尝试到google上搜索sess.run()运行越来越慢的原因,有找到如下类似问题:
这里提到如下在某一个循环里,不断建立tensorflow图节点再运行的话,会导致tensorflow运行越来越慢,有问题的代码结构大概长这样子:
for step in range(total_step):
tfops = tf add Ops ...
sess.run(tf.ops)
看到github上的分析,突然豁然开朗,tensorflow都是符号型结构的,它是在运行之前先建立好一张图并确认好张量的流向,再在迭代中不断喂数据进行训练的,如果我们在循环里不断的添加节点就导致tensorflow耗时在维护图结构上了。
github提供的解决思路是在sess.run()之前建立好图再运行,可是,不巧,我要做的事情,就是要运行动态的tensorflow图,即每一次运行的图结构都可能不一样,并非是固定图结构,且图像size不一,也没办法进行placeholder放置管道,这可没有任何答案告诉我怎么办啊!
怎么办呢?
我一开始调了代码结构,比如将多个需要session运行的操作放在同一个session里运行,可是实验反馈无效;另外一个,我主动在运行完图后销毁会话即使用sess.close()可是运行起来还是越来越慢;再然后,我试验在每一次调用sess.run()之前,调用tf.graph()并tf.graph.as_default()或tf.graph.as_graph_def()以为可以每一次都重新建立一张图来运行并让tensorflow自己销毁掉之前的图,可是,好事多磨啊,最终还是不行啊!
不过,我思路也是清晰的,我就是想要动态建立一张新图并销毁掉旧的图,那么我一直研究tensorflow Ops.py的源码,里面有定义了tensorflow图接口,最后查到里面有reset和finalize的图接口,想必就是我要的接口了,马上按照如下代码框架进行试验!
for step in range(total_step):
tf.reset_default_graph()
with tf.Session() as sess:
tfops = tf add Ops ...
sess.run(tfops)
tf.get_default_graph().finalize()
哇塞,验证结果显示,代码跑得飞快,哗啦啦的一直飞速运行就没有卡过,真是“有心人,天不负”!此时再看gtx 1080ti显卡的使用率已经基本维持在3%左右了,每一次运行耗时都基本在0.03s左右,六万张图像最终处理时间也不过三十分钟左右,速度提升还是很大的!